Implement new Hub Protocol (Part Deux) (#390)

* convert to new protocol
* removed InvocationDescriptorRegistry because we're not yet sure about custom protocols
* update SocialWeather sample
* Moving ts client to using new protocol
* make the functional tests a little easier to run on ctrl-f5
This commit is contained in:
Andrew Stanton-Nurse 2017-05-09 12:24:58 -07:00 committed by GitHub
parent 6cf6feed64
commit 991c1d8517
66 changed files with 1957 additions and 1951 deletions

View File

@ -1,6 +1,6 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
VisualStudioVersion = 15.0.26228.9
VisualStudioVersion = 15.0.26411.1
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}"
EndProject
@ -74,6 +74,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "client-ts", "client-ts", "{
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.SignalR.Microbenchmarks", "test\Microsoft.AspNetCore.SignalR.Microbenchmarks\Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj", "{96771B3F-4D18-41A7-A75B-FF38E76AAC89}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.Common.Tests", "test\Microsoft.AspNetCore.SignalR.Common.Tests\Microsoft.AspNetCore.SignalR.Common.Tests.csproj", "{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@ -180,6 +182,10 @@ Global
{96771B3F-4D18-41A7-A75B-FF38E76AAC89}.Debug|Any CPU.Build.0 = Debug|Any CPU
{96771B3F-4D18-41A7-A75B-FF38E76AAC89}.Release|Any CPU.ActiveCfg = Release|Any CPU
{96771B3F-4D18-41A7-A75B-FF38E76AAC89}.Release|Any CPU.Build.0 = Release|Any CPU
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@ -211,5 +217,6 @@ Global
{333526A4-633B-491A-AC45-CC62A0012D1C} = {3A76C5A2-79ED-49BC-8BDC-6A3A766FFA1B}
{6CEC3DC2-5B01-45A8-8F0D-8531315DA90B} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
{96771B3F-4D18-41A7-A75B-FF38E76AAC89} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
EndGlobalSection
EndGlobal

View File

@ -1,6 +1,7 @@
<Project>
<Project>
<PropertyGroup>
<AspNetCoreIntegrationTestingVersion>0.4.0-*</AspNetCoreIntegrationTestingVersion>
<SystemMemoryVersion>4.4.0-*</SystemMemoryVersion>
<AspNetCoreVersion>2.0.0-*</AspNetCoreVersion>
<CoreFxLabsVersion>0.1.0-*</CoreFxLabsVersion>
<CoreFxVersion>4.3.0</CoreFxVersion>

View File

@ -4,7 +4,7 @@ import { DataReceived, ConnectionClosed } from "../Microsoft.AspNetCore.SignalR.
import { TransportType, ITransport } from "../Microsoft.AspNetCore.SignalR.Client.TS/Transports"
describe("HubConnection", () => {
it("completes pending invocations when stopped", async (done) => {
it("completes pending invocations when stopped", async done => {
let connection: IConnection = {
start(transportType: TransportType | ITransport): Promise<void> {
return Promise.resolve();
@ -27,18 +27,18 @@ describe("HubConnection", () => {
let hubConnection = new HubConnection(connection);
var invokePromise = hubConnection.invoke("testMethod");
hubConnection.stop();
invokePromise
.then(() => {
fail();
done();
})
.catch((error: Error) => {
expect(error.message).toBe("Invocation cancelled due to connection being closed.");
done();
});
try {
await invokePromise;
fail();
}
catch (e) {
expect(e.message).toBe("Invocation cancelled due to connection being closed.");
}
done();
});
it("completes pending invocations when connection is lost", async (done) => {
it("completes pending invocations when connection is lost", async done => {
let connection: IConnection = {
start(transportType: TransportType | ITransport): Promise<void> {
return Promise.resolve();
@ -60,17 +60,92 @@ describe("HubConnection", () => {
let hubConnection = new HubConnection(connection);
var invokePromise = hubConnection.invoke("testMethod");
invokePromise
.then(() => {
fail();
done();
})
.catch((error: Error) => {
expect(error.message).toBe("Connection lost");
done();
});
// Typically this would be called by the transport
connection.onClosed(new Error("Connection lost"));
try {
await invokePromise;
fail();
}
catch (e) {
expect(e.message).toBe("Connection lost");
}
done();
});
it("sends invocations as nonblocking", async done => {
let dataSent: string;
let connection: IConnection = {
start(transportType: TransportType): Promise<void> {
return Promise.resolve();
},
send(data: any): Promise<void> {
dataSent = data;
return Promise.resolve();
},
stop(): void {
if (this.onClosed) {
this.onClosed();
}
},
onDataReceived: null,
onClosed: null
};
let hubConnection = new HubConnection(connection);
let invokePromise = hubConnection.invoke("testMethod");
expect(JSON.parse(dataSent).nonblocking).toBe(false);
// will clean pending promises
connection.onClosed();
try {
await invokePromise;
fail(); // exception is expected because the call has not completed
}
catch (e) {
}
done();
});
it("rejects streaming responses", async done => {
let connection: IConnection = {
start(transportType: TransportType): Promise<void> {
return Promise.resolve();
},
send(data: any): Promise<void> {
return Promise.resolve();
},
stop(): void {
if (this.onClosed) {
this.onClosed();
}
},
onDataReceived: null,
onClosed: null
};
let hubConnection = new HubConnection(connection);
let invokePromise = hubConnection.invoke("testMethod");
connection.onDataReceived("{ \"type\": 2, \"invocationId\": \"0\", \"result\": null }");
connection.onClosed();
try {
await invokePromise;
fail();
}
catch (e) {
expect(e.message).toBe("Streaming is not supported.");
}
done();
});
});

View File

@ -3,16 +3,31 @@ import { IConnection } from "./IConnection"
import { Connection } from "./Connection"
import { TransportType } from "./Transports"
interface InvocationDescriptor {
readonly Id: string;
readonly Method: string;
readonly Arguments: Array<any>;
const enum MessageType {
Invocation = 1,
Result,
Completion
}
interface InvocationResultDescriptor {
readonly Id: string;
readonly Error: string;
readonly Result: any;
interface HubMessage {
readonly type: MessageType;
readonly invocationId: string;
}
interface InvocationMessage extends HubMessage {
readonly target: string;
readonly arguments: Array<any>;
readonly nonblocking?: boolean;
}
interface ResultMessage extends HubMessage {
readonly result?: any;
}
interface CompletionMessage extends HubMessage {
readonly error?: string;
readonly result?: any;
}
export { Connection } from "./Connection"
@ -20,7 +35,7 @@ export { TransportType } from "./Transports"
export class HubConnection {
private connection: IConnection;
private callbacks: Map<string, (invocationDescriptor: InvocationResultDescriptor) => void>;
private callbacks: Map<string, (invocationUpdate: CompletionMessage|ResultMessage) => void>;
private methods: Map<string, (...args: any[]) => void>;
private id: number;
private connectionClosedCallback: ConnectionClosed;
@ -40,7 +55,7 @@ export class HubConnection {
this.onConnectionClosed(error);
}
this.callbacks = new Map<string, (invocationDescriptor: InvocationResultDescriptor) => void>();
this.callbacks = new Map<string, (invocationEvent: CompletionMessage|ResultMessage) => void>();
this.methods = new Map<string, (...args: any[]) => void>();
this.id = 0;
}
@ -51,34 +66,49 @@ export class HubConnection {
if (!data) {
return;
}
var descriptor = JSON.parse(data);
if (descriptor.Method === undefined) {
let invocationResult: InvocationResultDescriptor = descriptor;
let callback = this.callbacks.get(invocationResult.Id);
if (callback != null) {
callback(invocationResult);
this.callbacks.delete(invocationResult.Id);
var message = JSON.parse(data);
switch (message.type) {
case MessageType.Invocation:
this.InvokeClientMethod(<InvocationMessage>message);
break;
case MessageType.Result:
// TODO: Streaming (MessageType.Result) currently not supported - callback will throw
case MessageType.Completion:
let callback = this.callbacks.get(message.invocationId);
if (callback != null) {
callback(message);
this.callbacks.delete(message.invocationId);
}
break;
default:
console.log("Invalid message type: " + data);
break;
}
}
private InvokeClientMethod(invocationMessage: InvocationMessage) {
let method = this.methods.get(invocationMessage.target);
if (method) {
method.apply(this, invocationMessage.arguments);
if (!invocationMessage.nonblocking) {
// TODO: send result back to the server?
}
}
else {
let invocation: InvocationDescriptor = descriptor;
let method = this.methods[invocation.Method];
if (method != null) {
// TODO: bind? args?
method.apply(this, invocation.Arguments);
}
console.log(`No client method with the name '${invocationMessage.target}' found.`);
}
}
private onConnectionClosed(error: Error) {
let errorInvocationResult = {
Id: "-1",
Error: error ? error.message : "Invocation cancelled due to connection being closed.",
Result: null
} as InvocationResultDescriptor;
let errorCompletionMessage = <CompletionMessage>{
type: MessageType.Completion,
invocationId: "-1",
error: error ? error.message : "Invocation cancelled due to connection being closed.",
};
this.callbacks.forEach(callback => {
callback(errorInvocationResult);
callback(errorCompletionMessage);
});
this.callbacks.clear();
@ -99,19 +129,27 @@ export class HubConnection {
let id = this.id;
this.id++;
let invocationDescriptor: InvocationDescriptor = {
"Id": id.toString(),
"Method": methodName,
"Arguments": args
let invocationDescriptor: InvocationMessage = {
type: MessageType.Invocation,
invocationId: id.toString(),
target: methodName,
arguments: args,
nonblocking: false
};
let p = new Promise<any>((resolve, reject) => {
this.callbacks.set(invocationDescriptor.Id, (invocationResult: InvocationResultDescriptor) => {
if (invocationResult.Error != null) {
reject(new Error(invocationResult.Error));
this.callbacks.set(invocationDescriptor.invocationId, (invocationEvent: CompletionMessage | ResultMessage) => {
if (invocationEvent.type === MessageType.Completion) {
let completionMessage = <CompletionMessage>invocationEvent;
if (completionMessage.error) {
reject(new Error(completionMessage.error));
}
else {
resolve(completionMessage.result);
}
}
else {
resolve(invocationResult.Result);
reject(new Error("Streaming is not supported."))
}
});
@ -119,7 +157,7 @@ export class HubConnection {
this.connection.send(JSON.stringify(invocationDescriptor))
.catch(e => {
reject(e);
this.callbacks.delete(invocationDescriptor.Id);
this.callbacks.delete(invocationDescriptor.invocationId);
});
});
@ -127,7 +165,7 @@ export class HubConnection {
}
on(methodName: string, method: (...args: any[]) => void) {
this.methods[methodName] = method;
this.methods.set(methodName, method);
}
set onClosed(callback: ConnectionClosed) {

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.Builder;
@ -24,7 +24,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server
app.UseDeveloperExceptionPage();
}
app.UseStaticFiles();
app.UseFileServer();
app.UseSockets(options => options.MapEndpoint<EchoEndPoint>("/echo"));
app.UseSignalR(routes =>
{

View File

@ -0,0 +1,13 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>SignalR Tests</title>
</head>
<body>
<h1>SignalR Tests</h1>
<ul>
<li><a href="connectionTests.html">Connection Tests</a></li>
</ul>
</body>
</html>

View File

@ -5,7 +5,6 @@ using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.Extensions.CommandLineUtils;
using Microsoft.Extensions.Logging;
@ -33,7 +32,7 @@ namespace ClientSample
var loggerFactory = new LoggerFactory();
Console.WriteLine("Connecting to {0}", baseUrl);
var connection = new HubConnection(new Uri(baseUrl), new JsonNetInvocationAdapter(), loggerFactory);
var connection = new HubConnection(new Uri(baseUrl), loggerFactory);
try
{
await connection.StartAsync();

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
@ -22,6 +22,7 @@ namespace SocialWeather
public void OnConnectedAsync(Connection connection)
{
connection.Metadata[ConnectionMetadataNames.Format] = "json";
_connectionList.Add(connection);
}
@ -34,11 +35,11 @@ namespace SocialWeather
{
foreach (var connection in _connectionList)
{
var formatter = _formatterResolver.GetFormatter<T>(connection.Metadata.Get<string>("formatType"));
var formatter = _formatterResolver.GetFormatter<T>(connection.Metadata.Get<string>(ConnectionMetadataNames.Format));
var ms = new MemoryStream();
await formatter.WriteAsync(data, ms);
var context = (HttpContext)connection.Metadata[typeof(HttpContext)];
var context = (HttpContext)connection.Metadata[ConnectionMetadataNames.HttpContext];
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? MessageType.Binary

View File

@ -17,7 +17,7 @@ namespace SocketsSample.EndPoints
{
Connections.Add(connection);
await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})");
await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata[ConnectionMetadataNames.Transport]})");
try
{
@ -37,7 +37,7 @@ namespace SocketsSample.EndPoints
{
Connections.Remove(connection);
await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})");
await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata[ConnectionMetadataNames.Transport]})");
}
}

View File

@ -1,91 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
namespace SocketsSample
{
public class LineInvocationAdapter : IInvocationAdapter
{
public async Task<InvocationMessage> ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken)
{
var streamReader = new StreamReader(stream);
var line = await streamReader.ReadLineAsync();
if (line == null)
{
return null;
}
var values = line.Split(',');
var type = values[0].Substring(0, 2);
var id = values[0].Substring(2);
if (type.Equals("RI"))
{
var resultType = values[1].Substring(0, 1);
var result = values[1].Substring(1);
return new InvocationResultDescriptor()
{
Id = id,
Result = resultType.Equals("E") ? null : result,
Error = resultType.Equals("E") ? result : null,
};
}
else
{
var method = values[1].Substring(1);
return new InvocationDescriptor
{
Id = id,
Method = method,
Arguments = values.Skip(2).Zip(binder.GetParameterTypes(method), (v, t) => Convert.ChangeType(v, t)).ToArray()
};
}
}
public Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken)
{
var invocationDescriptor = message as InvocationDescriptor;
if (invocationDescriptor != null)
{
return WriteInvocationDescriptorAsync(invocationDescriptor, stream);
}
else
{
return WriteInvocationResultAsync((InvocationResultDescriptor)message, stream);
}
}
private Task WriteInvocationDescriptorAsync(InvocationDescriptor invocationDescriptor, Stream stream)
{
var msg = $"CI{invocationDescriptor.Id},M{invocationDescriptor.Method},{string.Join(",", invocationDescriptor.Arguments.Select(a => a.ToString()))}\n";
return WriteAsync(msg, stream);
}
private Task WriteInvocationResultAsync(InvocationResultDescriptor resultDescriptor, Stream stream)
{
if (string.IsNullOrEmpty(resultDescriptor.Error))
{
return WriteAsync($"RI{resultDescriptor.Id},E{resultDescriptor.Error}\n", stream);
}
else
{
return WriteAsync($"RI{resultDescriptor.Id},R{(resultDescriptor.Result != null ? resultDescriptor.Result.ToString() : string.Empty)}\n", stream);
}
}
private async Task WriteAsync(string msg, Stream stream)
{
var writer = new StreamWriter(stream);
await writer.WriteAsync(msg);
await writer.FlushAsync();
}
}
}

View File

@ -1,168 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.DependencyInjection;
namespace SocketsSample.Protobuf
{
public class ProtobufInvocationAdapter : IInvocationAdapter
{
private IServiceProvider _serviceProvider;
public ProtobufInvocationAdapter(IServiceProvider serviceProvider)
{
_serviceProvider = serviceProvider;
}
public Task<InvocationMessage> ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken)
{
return Task.Run(() => CreateInvocationMessageInt(stream, binder));
}
public Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken)
{
throw new NotImplementedException();
}
private Task<InvocationMessage> CreateInvocationMessageInt(Stream stream, IInvocationBinder binder)
{
var inputStream = new CodedInputStream(stream, leaveOpen: true);
var messageKind = new RpcMessageKind();
inputStream.ReadMessage(messageKind);
if(messageKind.MessageKind == RpcMessageKind.Types.Kind.Invocation)
{
return CreateInvocationDescriptorInt(inputStream, binder);
}
else
{
return CreateInvocationResultDescriptorInt(inputStream, binder);
}
}
private Task<InvocationMessage> CreateInvocationResultDescriptorInt(CodedInputStream inputStream, IInvocationBinder binder)
{
throw new NotImplementedException("Not yet implemented for Protobuf");
}
private Task<InvocationMessage> CreateInvocationDescriptorInt(CodedInputStream inputStream, IInvocationBinder binder)
{
var invocationHeader = new RpcInvocationHeader();
inputStream.ReadMessage(invocationHeader);
var argumentTypes = binder.GetParameterTypes(invocationHeader.Name);
var invocationDescriptor = new InvocationDescriptor();
invocationDescriptor.Method = invocationHeader.Name;
invocationDescriptor.Id = invocationHeader.Id.ToString();
invocationDescriptor.Arguments = new object[argumentTypes.Length];
var primitiveParser = PrimitiveValue.Parser;
for (var i = 0; i < argumentTypes.Length; i++)
{
if (typeof(int) == argumentTypes[i])
{
var value = new PrimitiveValue();
inputStream.ReadMessage(value);
invocationDescriptor.Arguments[i] = value.Int32Value;
}
else if (typeof(string) == argumentTypes[i])
{
var value = new PrimitiveValue();
inputStream.ReadMessage(value);
invocationDescriptor.Arguments[i] = value.StringValue;
}
else
{
var serializer = _serviceProvider.GetRequiredService<ProtobufSerializer>();
invocationDescriptor.Arguments[i] = serializer.GetValue(inputStream, argumentTypes[i]);
}
}
return Task.FromResult<InvocationMessage>(invocationDescriptor);
}
public async Task WriteInvocationResultAsync(InvocationResultDescriptor resultDescriptor, Stream stream)
{
var outputStream = new CodedOutputStream(stream, leaveOpen: true);
outputStream.WriteMessage(new RpcMessageKind() { MessageKind = RpcMessageKind.Types.Kind.Result });
var resultHeader = new RpcInvocationResultHeader
{
Id = int.Parse(resultDescriptor.Id),
HasResult = resultDescriptor.Result != null
};
if (resultDescriptor.Error != null)
{
resultHeader.Error = resultDescriptor.Error;
}
outputStream.WriteMessage(resultHeader);
if (string.IsNullOrEmpty(resultHeader.Error) && resultDescriptor.Result != null)
{
var result = resultDescriptor.Result;
if (result.GetType() == typeof(int))
{
outputStream.WriteMessage(new PrimitiveValue { Int32Value = (int)result });
}
else if (result.GetType() == typeof(string))
{
outputStream.WriteMessage(new PrimitiveValue { StringValue = (string)result });
}
else
{
var serializer = _serviceProvider.GetRequiredService<ProtobufSerializer>();
var message = serializer.GetMessage(result);
outputStream.WriteMessage(message);
}
}
outputStream.Flush();
await stream.FlushAsync();
}
public async Task WriteInvocationDescriptorAsync(InvocationDescriptor invocationDescriptor, Stream stream)
{
var outputStream = new CodedOutputStream(stream, leaveOpen: true);
outputStream.WriteMessage(new RpcMessageKind() { MessageKind = RpcMessageKind.Types.Kind.Invocation });
var invocationHeader = new RpcInvocationHeader()
{
Id = 0,
Name = invocationDescriptor.Method,
NumArgs = invocationDescriptor.Arguments.Length
};
outputStream.WriteMessage(invocationHeader);
foreach (var arg in invocationDescriptor.Arguments)
{
if (arg.GetType() == typeof(int))
{
outputStream.WriteMessage(new PrimitiveValue { Int32Value = (int)arg });
}
else if (arg.GetType() == typeof(string))
{
outputStream.WriteMessage(new PrimitiveValue { StringValue = (string)arg });
}
else
{
var serializer = _serviceProvider.GetRequiredService<ProtobufSerializer>();
var message = serializer.GetMessage(arg);
outputStream.WriteMessage(message);
}
}
outputStream.Flush();
await stream.FlushAsync();
}
}
}

View File

@ -1,848 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: RpcInvocation.proto
#pragma warning disable 1591, 0612, 3021
#region Designer generated code
using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
/// <summary>Holder for reflection information generated from RpcInvocation.proto</summary>
public static partial class RpcInvocationReflection {
#region Descriptor
/// <summary>File descriptor for RpcInvocation.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;
static RpcInvocationReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"ChNScGNJbnZvY2F0aW9uLnByb3RvIl8KDlJwY01lc3NhZ2VLaW5kEikKC01l",
"c3NhZ2VLaW5kGAEgASgOMhQuUnBjTWVzc2FnZUtpbmQuS2luZCIiCgRLaW5k",
"EgoKBlJlc3VsdBAAEg4KCkludm9jYXRpb24QASJAChNScGNJbnZvY2F0aW9u",
"SGVhZGVyEgwKBE5hbWUYASABKAkSCgoCSWQYAiABKAUSDwoHTnVtQXJncxgD",
"IAEoBSJJChlScGNJbnZvY2F0aW9uUmVzdWx0SGVhZGVyEgoKAklkGAEgASgF",
"EhEKCUhhc1Jlc3VsdBgCIAEoCBINCgVFcnJvchgDIAEoCSJHCg5QcmltaXRp",
"dmVWYWx1ZRIUCgpJbnQzMlZhbHVlGAEgASgFSAASFQoLU3RyaW5nVmFsdWUY",
"AiABKAlIAEIICgZvbmVvZl8iKgoNUGVyc29uTWVzc2FnZRIMCgROYW1lGAEg",
"ASgJEgsKA0FnZRgCIAEoA2IGcHJvdG8z"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::RpcMessageKind), global::RpcMessageKind.Parser, new[]{ "MessageKind" }, null, new[]{ typeof(global::RpcMessageKind.Types.Kind) }, null),
new pbr::GeneratedClrTypeInfo(typeof(global::RpcInvocationHeader), global::RpcInvocationHeader.Parser, new[]{ "Name", "Id", "NumArgs" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::RpcInvocationResultHeader), global::RpcInvocationResultHeader.Parser, new[]{ "Id", "HasResult", "Error" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::PrimitiveValue), global::PrimitiveValue.Parser, new[]{ "Int32Value", "StringValue" }, new[]{ "Oneof" }, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::PersonMessage), global::PersonMessage.Parser, new[]{ "Name", "Age" }, null, null, null)
}));
}
#endregion
}
#region Messages
public sealed partial class RpcMessageKind : pb::IMessage<RpcMessageKind> {
private static readonly pb::MessageParser<RpcMessageKind> _parser = new pb::MessageParser<RpcMessageKind>(() => new RpcMessageKind());
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<RpcMessageKind> Parser { get { return _parser; } }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[0]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcMessageKind() {
OnConstruction();
}
partial void OnConstruction();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcMessageKind(RpcMessageKind other) : this() {
messageKind_ = other.messageKind_;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcMessageKind Clone() {
return new RpcMessageKind(this);
}
/// <summary>Field number for the "MessageKind" field.</summary>
public const int MessageKindFieldNumber = 1;
private global::RpcMessageKind.Types.Kind messageKind_ = 0;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::RpcMessageKind.Types.Kind MessageKind {
get { return messageKind_; }
set {
messageKind_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as RpcMessageKind);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(RpcMessageKind other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (MessageKind != other.MessageKind) return false;
return true;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (MessageKind != 0) hash ^= MessageKind.GetHashCode();
return hash;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (MessageKind != 0) {
output.WriteRawTag(8);
output.WriteEnum((int) MessageKind);
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (MessageKind != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) MessageKind);
}
return size;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(RpcMessageKind other) {
if (other == null) {
return;
}
if (other.MessageKind != 0) {
MessageKind = other.MessageKind;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
input.SkipLastField();
break;
case 8: {
messageKind_ = (global::RpcMessageKind.Types.Kind) input.ReadEnum();
break;
}
}
}
}
#region Nested types
/// <summary>Container for nested types declared in the RpcMessageKind message type.</summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static partial class Types {
public enum Kind {
[pbr::OriginalName("Result")] Result = 0,
[pbr::OriginalName("Invocation")] Invocation = 1,
}
}
#endregion
}
public sealed partial class RpcInvocationHeader : pb::IMessage<RpcInvocationHeader> {
private static readonly pb::MessageParser<RpcInvocationHeader> _parser = new pb::MessageParser<RpcInvocationHeader>(() => new RpcInvocationHeader());
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<RpcInvocationHeader> Parser { get { return _parser; } }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[1]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcInvocationHeader() {
OnConstruction();
}
partial void OnConstruction();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcInvocationHeader(RpcInvocationHeader other) : this() {
name_ = other.name_;
id_ = other.id_;
numArgs_ = other.numArgs_;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcInvocationHeader Clone() {
return new RpcInvocationHeader(this);
}
/// <summary>Field number for the "Name" field.</summary>
public const int NameFieldNumber = 1;
private string name_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string Name {
get { return name_; }
set {
name_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}
/// <summary>Field number for the "Id" field.</summary>
public const int IdFieldNumber = 2;
private int id_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int Id {
get { return id_; }
set {
id_ = value;
}
}
/// <summary>Field number for the "NumArgs" field.</summary>
public const int NumArgsFieldNumber = 3;
private int numArgs_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int NumArgs {
get { return numArgs_; }
set {
numArgs_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as RpcInvocationHeader);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(RpcInvocationHeader other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (Name != other.Name) return false;
if (Id != other.Id) return false;
if (NumArgs != other.NumArgs) return false;
return true;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (Name.Length != 0) hash ^= Name.GetHashCode();
if (Id != 0) hash ^= Id.GetHashCode();
if (NumArgs != 0) hash ^= NumArgs.GetHashCode();
return hash;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (Name.Length != 0) {
output.WriteRawTag(10);
output.WriteString(Name);
}
if (Id != 0) {
output.WriteRawTag(16);
output.WriteInt32(Id);
}
if (NumArgs != 0) {
output.WriteRawTag(24);
output.WriteInt32(NumArgs);
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (Name.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(Name);
}
if (Id != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id);
}
if (NumArgs != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumArgs);
}
return size;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(RpcInvocationHeader other) {
if (other == null) {
return;
}
if (other.Name.Length != 0) {
Name = other.Name;
}
if (other.Id != 0) {
Id = other.Id;
}
if (other.NumArgs != 0) {
NumArgs = other.NumArgs;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
input.SkipLastField();
break;
case 10: {
Name = input.ReadString();
break;
}
case 16: {
Id = input.ReadInt32();
break;
}
case 24: {
NumArgs = input.ReadInt32();
break;
}
}
}
}
}
public sealed partial class RpcInvocationResultHeader : pb::IMessage<RpcInvocationResultHeader> {
private static readonly pb::MessageParser<RpcInvocationResultHeader> _parser = new pb::MessageParser<RpcInvocationResultHeader>(() => new RpcInvocationResultHeader());
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<RpcInvocationResultHeader> Parser { get { return _parser; } }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[2]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcInvocationResultHeader() {
OnConstruction();
}
partial void OnConstruction();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcInvocationResultHeader(RpcInvocationResultHeader other) : this() {
id_ = other.id_;
hasResult_ = other.hasResult_;
error_ = other.error_;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcInvocationResultHeader Clone() {
return new RpcInvocationResultHeader(this);
}
/// <summary>Field number for the "Id" field.</summary>
public const int IdFieldNumber = 1;
private int id_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int Id {
get { return id_; }
set {
id_ = value;
}
}
/// <summary>Field number for the "HasResult" field.</summary>
public const int HasResultFieldNumber = 2;
private bool hasResult_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool HasResult {
get { return hasResult_; }
set {
hasResult_ = value;
}
}
/// <summary>Field number for the "Error" field.</summary>
public const int ErrorFieldNumber = 3;
private string error_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string Error {
get { return error_; }
set {
error_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as RpcInvocationResultHeader);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(RpcInvocationResultHeader other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (Id != other.Id) return false;
if (HasResult != other.HasResult) return false;
if (Error != other.Error) return false;
return true;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (Id != 0) hash ^= Id.GetHashCode();
if (HasResult != false) hash ^= HasResult.GetHashCode();
if (Error.Length != 0) hash ^= Error.GetHashCode();
return hash;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (Id != 0) {
output.WriteRawTag(8);
output.WriteInt32(Id);
}
if (HasResult != false) {
output.WriteRawTag(16);
output.WriteBool(HasResult);
}
if (Error.Length != 0) {
output.WriteRawTag(26);
output.WriteString(Error);
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (Id != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id);
}
if (HasResult != false) {
size += 1 + 1;
}
if (Error.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(Error);
}
return size;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(RpcInvocationResultHeader other) {
if (other == null) {
return;
}
if (other.Id != 0) {
Id = other.Id;
}
if (other.HasResult != false) {
HasResult = other.HasResult;
}
if (other.Error.Length != 0) {
Error = other.Error;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
input.SkipLastField();
break;
case 8: {
Id = input.ReadInt32();
break;
}
case 16: {
HasResult = input.ReadBool();
break;
}
case 26: {
Error = input.ReadString();
break;
}
}
}
}
}
public sealed partial class PrimitiveValue : pb::IMessage<PrimitiveValue> {
private static readonly pb::MessageParser<PrimitiveValue> _parser = new pb::MessageParser<PrimitiveValue>(() => new PrimitiveValue());
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<PrimitiveValue> Parser { get { return _parser; } }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[3]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public PrimitiveValue() {
OnConstruction();
}
partial void OnConstruction();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public PrimitiveValue(PrimitiveValue other) : this() {
switch (other.OneofCase) {
case OneofOneofCase.Int32Value:
Int32Value = other.Int32Value;
break;
case OneofOneofCase.StringValue:
StringValue = other.StringValue;
break;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public PrimitiveValue Clone() {
return new PrimitiveValue(this);
}
/// <summary>Field number for the "Int32Value" field.</summary>
public const int Int32ValueFieldNumber = 1;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int Int32Value {
get { return oneofCase_ == OneofOneofCase.Int32Value ? (int) oneof_ : 0; }
set {
oneof_ = value;
oneofCase_ = OneofOneofCase.Int32Value;
}
}
/// <summary>Field number for the "StringValue" field.</summary>
public const int StringValueFieldNumber = 2;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string StringValue {
get { return oneofCase_ == OneofOneofCase.StringValue ? (string) oneof_ : ""; }
set {
oneof_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
oneofCase_ = OneofOneofCase.StringValue;
}
}
private object oneof_;
/// <summary>Enum of possible cases for the "oneof_" oneof.</summary>
public enum OneofOneofCase {
None = 0,
Int32Value = 1,
StringValue = 2,
}
private OneofOneofCase oneofCase_ = OneofOneofCase.None;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public OneofOneofCase OneofCase {
get { return oneofCase_; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void ClearOneof() {
oneofCase_ = OneofOneofCase.None;
oneof_ = null;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as PrimitiveValue);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(PrimitiveValue other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (Int32Value != other.Int32Value) return false;
if (StringValue != other.StringValue) return false;
if (OneofCase != other.OneofCase) return false;
return true;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (oneofCase_ == OneofOneofCase.Int32Value) hash ^= Int32Value.GetHashCode();
if (oneofCase_ == OneofOneofCase.StringValue) hash ^= StringValue.GetHashCode();
hash ^= (int) oneofCase_;
return hash;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (oneofCase_ == OneofOneofCase.Int32Value) {
output.WriteRawTag(8);
output.WriteInt32(Int32Value);
}
if (oneofCase_ == OneofOneofCase.StringValue) {
output.WriteRawTag(18);
output.WriteString(StringValue);
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (oneofCase_ == OneofOneofCase.Int32Value) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Int32Value);
}
if (oneofCase_ == OneofOneofCase.StringValue) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(StringValue);
}
return size;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(PrimitiveValue other) {
if (other == null) {
return;
}
switch (other.OneofCase) {
case OneofOneofCase.Int32Value:
Int32Value = other.Int32Value;
break;
case OneofOneofCase.StringValue:
StringValue = other.StringValue;
break;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
input.SkipLastField();
break;
case 8: {
Int32Value = input.ReadInt32();
break;
}
case 18: {
StringValue = input.ReadString();
break;
}
}
}
}
}
public sealed partial class PersonMessage : pb::IMessage<PersonMessage> {
private static readonly pb::MessageParser<PersonMessage> _parser = new pb::MessageParser<PersonMessage>(() => new PersonMessage());
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<PersonMessage> Parser { get { return _parser; } }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[4]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public PersonMessage() {
OnConstruction();
}
partial void OnConstruction();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public PersonMessage(PersonMessage other) : this() {
name_ = other.name_;
age_ = other.age_;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public PersonMessage Clone() {
return new PersonMessage(this);
}
/// <summary>Field number for the "Name" field.</summary>
public const int NameFieldNumber = 1;
private string name_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string Name {
get { return name_; }
set {
name_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}
/// <summary>Field number for the "Age" field.</summary>
public const int AgeFieldNumber = 2;
private long age_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public long Age {
get { return age_; }
set {
age_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as PersonMessage);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(PersonMessage other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (Name != other.Name) return false;
if (Age != other.Age) return false;
return true;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (Name.Length != 0) hash ^= Name.GetHashCode();
if (Age != 0L) hash ^= Age.GetHashCode();
return hash;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (Name.Length != 0) {
output.WriteRawTag(10);
output.WriteString(Name);
}
if (Age != 0L) {
output.WriteRawTag(16);
output.WriteInt64(Age);
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (Name.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(Name);
}
if (Age != 0L) {
size += 1 + pb::CodedOutputStream.ComputeInt64Size(Age);
}
return size;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(PersonMessage other) {
if (other == null) {
return;
}
if (other.Name.Length != 0) {
Name = other.Name;
}
if (other.Age != 0L) {
Age = other.Age;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
input.SkipLastField();
break;
case 10: {
Name = input.ReadString();
break;
}
case 16: {
Age = input.ReadInt64();
break;
}
}
}
}
}
#endregion
#endregion Designer generated code

View File

@ -1,30 +0,0 @@
syntax = "proto3";
message RpcMessageKind {
enum Kind { Result = 0; Invocation = 1; }
Kind MessageKind = 1;
}
message RpcInvocationHeader {
string Name = 1;
int32 Id = 2;
int32 NumArgs = 3;
}
message RpcInvocationResultHeader {
int32 Id = 1;
bool HasResult = 2;
string Error = 3;
}
message PrimitiveValue {
oneof oneof_ {
int32 Int32Value = 1;
string StringValue = 2;
}
}
message PersonMessage {
string Name = 1;
int64 Age = 2;
}

View File

@ -1,36 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Google.Protobuf;
using SocketsSample.Hubs;
namespace SocketsSample
{
public class ProtobufSerializer
{
public object GetValue(CodedInputStream inputStream, Type type)
{
if (type == typeof(Person))
{
var value = new PersonMessage();
inputStream.ReadMessage(value);
return new Person { Name = value.Name, Age = value.Age };
}
throw new InvalidOperationException("(Deserialize) Unknown type.");
}
public IMessage GetMessage(object value)
{
Person person = value as Person;
if (person != null)
{
return new PersonMessage { Name = person.Name, Age = person.Age };
}
throw new InvalidOperationException("(Serialize) Unknown type.");
}
}
}

View File

@ -3,12 +3,10 @@
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using SocketsSample.EndPoints;
using SocketsSample.Hubs;
using SocketsSample.Protobuf;
namespace SocketsSample
{
@ -18,21 +16,12 @@ namespace SocketsSample
// For more information on how to configure your application, visit http://go.microsoft.com/fwlink/?LinkID=398940
public void ConfigureServices(IServiceCollection services)
{
services.AddSingleton<ProtobufInvocationAdapter>();
services.AddSingleton<LineInvocationAdapter>();
services.AddSockets();
services.AddSignalR(options =>
{
options.RegisterInvocationAdapter<ProtobufInvocationAdapter>("protobuf");
options.RegisterInvocationAdapter<LineInvocationAdapter>("line");
});
services.AddSignalR();
// .AddRedis();
services.AddEndPoint<MessagesEndPoint>();
services.AddSingleton<ProtobufSerializer>();
}
// This method gets called by the runtime. Use this method to configure the HTTP request pipeline.

View File

@ -1,19 +1,20 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Client
{
@ -22,17 +23,14 @@ namespace Microsoft.AspNetCore.SignalR.Client
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
private readonly IConnection _connection;
private readonly IInvocationAdapter _adapter;
private readonly IHubProtocol _protocol;
private readonly HubBinder _binder;
private HttpClient _httpClient;
private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();
// We need to ensure pending calls added after a connection failure don't hang. Right now the easiest thing to do is lock.
private readonly object _pendingCallsLock = new object();
private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();
private readonly Dictionary<string, InvocationRequest> _pendingCalls = new Dictionary<string, InvocationRequest>();
private readonly ConcurrentDictionary<string, InvocationHandler> _handlers = new ConcurrentDictionary<string, InvocationHandler>();
private int _nextId = 0;
@ -50,27 +48,33 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
public HubConnection(Uri url)
: this(new Connection(url), new JsonNetInvocationAdapter(), null)
: this(new Connection(url), new JsonHubProtocol(new JsonSerializer()), null)
{ }
public HubConnection(Uri url, ILoggerFactory loggerFactory)
: this(new Connection(url), new JsonNetInvocationAdapter(), loggerFactory)
: this(new Connection(url), new JsonHubProtocol(new JsonSerializer()), loggerFactory)
{ }
public HubConnection(Uri url, IInvocationAdapter adapter, ILoggerFactory loggerFactory)
: this(new Connection(url, loggerFactory), adapter, loggerFactory)
// These are only really needed for tests now...
public HubConnection(IConnection connection, ILoggerFactory loggerFactory)
: this(connection, new JsonHubProtocol(new JsonSerializer()), loggerFactory)
{ }
public HubConnection(IConnection connection, IInvocationAdapter adapter, ILoggerFactory loggerFactory)
public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFactory loggerFactory)
{
if (connection == null)
{
throw new ArgumentNullException(nameof(connection));
}
if (protocol == null)
{
throw new ArgumentNullException(nameof(protocol));
}
_connection = connection;
_binder = new HubBinder(this);
_adapter = adapter;
_protocol = protocol;
_loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
_logger = _loggerFactory.CreateLogger<HubConnection>();
_connection.Received += OnDataReceived;
@ -100,6 +104,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
public async Task DisposeAsync()
{
await _connection.DisposeAsync();
_httpClient?.Dispose();
}
@ -118,81 +123,80 @@ namespace Microsoft.AspNetCore.SignalR.Client
public Task<object> Invoke(string methodName, Type returnType, params object[] args) => Invoke(methodName, returnType, CancellationToken.None, args);
public async Task<object> Invoke(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args)
{
_logger.LogTrace("Preparing invocation of '{0}', with return type '{1}' and {2} args", methodName, returnType.AssemblyQualifiedName, args.Length);
ThrowIfConnectionTerminated();
_logger.LogTrace("Preparing invocation of '{target}', with return type '{returnType}' and {argumentCount} args", methodName, returnType.AssemblyQualifiedName, args.Length);
// Create an invocation descriptor.
var descriptor = new InvocationDescriptor
{
Id = GetNextId(),
Method = methodName,
Arguments = args
};
// Create an invocation descriptor. Client invocations are always blocking
var invocationMessage = new InvocationMessage(GetNextId(), nonBlocking: false, target: methodName, arguments: args);
// I just want an excuse to use 'irq' as a variable name...
_logger.LogDebug("Registering Invocation ID '{0}' for tracking", descriptor.Id);
var irq = new InvocationRequest(cancellationToken, returnType);
_logger.LogDebug("Registering Invocation ID '{invocationId}' for tracking", invocationMessage.InvocationId);
var irq = new InvocationRequest(cancellationToken, returnType, invocationMessage.InvocationId, _loggerFactory);
lock (_pendingCallsLock)
{
if (_connectionActive.IsCancellationRequested)
{
throw new InvalidOperationException("Connection has been terminated.");
}
_pendingCalls.Add(descriptor.Id, irq);
}
AddInvocation(irq);
// Trace the invocation, but only if that logging level is enabled (because building the args list is a bit slow)
// Trace the full invocation, but only if that logging level is enabled (because building the args list is a bit slow)
if (_logger.IsEnabled(LogLevel.Trace))
{
var argsList = string.Join(", ", args.Select(a => a.GetType().FullName));
_logger.LogTrace("Invocation #{0}: {1} {2}({3})", descriptor.Id, returnType.FullName, methodName, argsList);
_logger.LogTrace("Issuing Invocation '{invocationId}': {returnType} {methodName}({args})", invocationMessage.InvocationId, returnType.FullName, methodName, argsList);
}
try
{
var ms = new MemoryStream();
await _adapter.WriteMessageAsync(descriptor, ms, cancellationToken);
var payload = await _protocol.WriteToArrayAsync(invocationMessage);
_logger.LogInformation("Sending Invocation #{0}", descriptor.Id);
_logger.LogInformation("Sending Invocation '{invocationId}'", invocationMessage.InvocationId);
// TODO: Format.Text - who, where and when decides about the format of outgoing messages
await _connection.SendAsync(ms.ToArray(), MessageType.Text, cancellationToken);
_logger.LogInformation("Sending Invocation #{0} complete", descriptor.Id);
await _connection.SendAsync(payload, _protocol.MessageType, cancellationToken);
_logger.LogInformation("Sending Invocation '{invocationId}' complete", invocationMessage.InvocationId);
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Sending Invocation #{0} failed", descriptor.Id);
irq.Completion.TrySetException(ex);
lock (_pendingCallsLock)
{
_pendingCalls.Remove(descriptor.Id);
}
_logger.LogError(0, ex, "Sending Invocation '{invocationId}' failed", invocationMessage.InvocationId);
irq.Fail(ex);
TryRemoveInvocation(invocationMessage.InvocationId, out _);
}
// Return the completion task. It will be completed by ReceiveMessages when the response is received.
return await irq.Completion.Task;
return await irq.Completion;
}
private async void OnDataReceived(byte[] data, MessageType messageType)
private void OnDataReceived(byte[] data, MessageType messageType)
{
var message
= await _adapter.ReadMessageAsync(new MemoryStream(data), _binder, _connectionActive.Token);
var message = _protocol.ParseMessage(data, _binder);
InvocationRequest irq;
switch (message)
{
case InvocationDescriptor invocationDescriptor:
DispatchInvocation(invocationDescriptor, _connectionActive.Token);
break;
case InvocationResultDescriptor invocationResultDescriptor:
InvocationRequest irq;
lock (_pendingCallsLock)
case InvocationMessage invocation:
if (_logger.IsEnabled(LogLevel.Trace))
{
_connectionActive.Token.ThrowIfCancellationRequested();
irq = _pendingCalls[invocationResultDescriptor.Id];
_pendingCalls.Remove(invocationResultDescriptor.Id);
var argsList = string.Join(", ", invocation.Arguments.Select(a => a.GetType().FullName));
_logger.LogTrace("Received Invocation '{invocationId}': {methodName}({args})", invocation.InvocationId, invocation.Target, argsList);
}
DispatchInvocationResult(invocationResultDescriptor, irq, _connectionActive.Token);
DispatchInvocation(invocation, _connectionActive.Token);
break;
case CompletionMessage completion:
if (!TryRemoveInvocation(completion.InvocationId, out irq))
{
_logger.LogWarning("Dropped unsolicited Completion message for invocation '{invocationId}'", completion.InvocationId);
return;
}
DispatchInvocationCompletion(completion, irq);
irq.Dispose();
break;
case StreamItemMessage streamItem:
// Complete the invocation with an error, we don't support streaming (yet)
if (!TryRemoveInvocation(streamItem.InvocationId, out irq))
{
_logger.LogWarning("Dropped unsolicited Stream Item message for invocation '{invocationId}'", streamItem.InvocationId);
return;
}
irq.Fail(new NotSupportedException("Streaming method results are not supported"));
break;
default:
throw new InvalidOperationException($"Unknown message type: {message.GetType().FullName}");
}
}
@ -201,74 +205,118 @@ namespace Microsoft.AspNetCore.SignalR.Client
_logger.LogTrace("Shutting down connection");
if (ex != null)
{
_logger.LogError("Connection is shutting down due to an error: {0}", ex);
_logger.LogError(ex, "Connection is shutting down due to an error");
}
lock (_pendingCallsLock)
{
// We cancel inside the lock to make sure everyone who was part-way through registering an invocation
// completes. This also ensures that nobody will add things to _pendingCalls after we leave this block
// because everything that adds to _pendingCalls checks _connectionActive first (inside the _pendingCallsLock)
_connectionActive.Cancel();
foreach (var call in _pendingCalls.Values)
foreach (var outstandingCall in _pendingCalls.Values)
{
_logger.LogTrace("Removing pending call {invocationId}", outstandingCall.InvocationId);
if (ex != null)
{
call.Completion.TrySetException(ex);
}
else
{
call.Completion.TrySetCanceled();
outstandingCall.Fail(ex);
}
outstandingCall.Dispose();
}
_pendingCalls.Clear();
}
}
private void DispatchInvocation(InvocationDescriptor invocationDescriptor, CancellationToken cancellationToken)
private void DispatchInvocation(InvocationMessage invocation, CancellationToken cancellationToken)
{
// Find the handler
if (!_handlers.TryGetValue(invocationDescriptor.Method, out InvocationHandler handler))
if (!_handlers.TryGetValue(invocation.Target, out InvocationHandler handler))
{
_logger.LogWarning("Failed to find handler for '{0}' method", invocationDescriptor.Method);
_logger.LogWarning("Failed to find handler for '{target}' method", invocation.Target);
return;
}
// TODO: Return values
// TODO: Dispatch to a sync context to ensure we aren't blocking this loop.
handler.Handler(invocationDescriptor.Arguments);
handler.Handler(invocation.Arguments);
}
private void DispatchInvocationResult(InvocationResultDescriptor result, InvocationRequest irq, CancellationToken cancellationToken)
private void DispatchInvocationCompletion(CompletionMessage completion, InvocationRequest irq)
{
_logger.LogInformation("Received Result for Invocation #{0}", result.Id);
_logger.LogTrace("Received Completion for Invocation #{invocationId}", completion.InvocationId);
if (cancellationToken.IsCancellationRequested)
if (irq.CancellationToken.IsCancellationRequested)
{
return;
_logger.LogTrace("Cancelling dispatch of Completion message for Invocation {invocationId}. The invocation was cancelled.", irq.InvocationId);
}
Debug.Assert(irq.Completion != null, "Didn't properly capture InvocationRequest in callback for ReadInvocationResultDescriptorAsync");
// If the invocation hasn't been cancelled, dispatch the result
if (!irq.CancellationToken.IsCancellationRequested)
else
{
irq.Registration.Dispose();
// Complete the request based on the result
// TODO: the TrySetXYZ methods will cause continuations attached to the Task to run, so we should dispatch to a sync context or thread pool.
if (!string.IsNullOrEmpty(result.Error))
if (!string.IsNullOrEmpty(completion.Error))
{
_logger.LogInformation("Completing Invocation #{0} with error: {1}", result.Id, result.Error);
irq.Completion.TrySetException(new Exception(result.Error));
irq.Fail(new HubException(completion.Error));
}
else
{
_logger.LogInformation("Completing Invocation #{0} with result of type: {1}", result.Id, result.Result?.GetType()?.FullName ?? "<<void>>");
irq.Completion.TrySetResult(result.Result);
irq.Complete(completion.Result);
}
}
}
private void ThrowIfConnectionTerminated()
{
if (_connectionActive.Token.IsCancellationRequested)
{
_logger.LogError("Invoke was called after the connection was terminated");
throw new InvalidOperationException("Connection has been terminated.");
}
}
private string GetNextId() => Interlocked.Increment(ref _nextId).ToString();
private void AddInvocation(InvocationRequest irq)
{
lock (_pendingCallsLock)
{
ThrowIfConnectionTerminated();
if (_pendingCalls.ContainsKey(irq.InvocationId))
{
_logger.LogCritical("Invocation ID '{invocationId}' is already in use.", irq.InvocationId);
throw new InvalidOperationException($"Invocation ID '{irq.InvocationId}' is already in use.");
}
else
{
_pendingCalls.Add(irq.InvocationId, irq);
}
}
}
private bool TryGetInvocation(string invocationId, out InvocationRequest irq)
{
lock (_pendingCallsLock)
{
ThrowIfConnectionTerminated();
return _pendingCalls.TryGetValue(invocationId, out irq);
}
}
private bool TryRemoveInvocation(string invocationId, out InvocationRequest irq)
{
lock (_pendingCallsLock)
{
ThrowIfConnectionTerminated();
if (_pendingCalls.TryGetValue(invocationId, out irq))
{
_pendingCalls.Remove(invocationId);
return true;
}
else
{
return false;
}
}
}
private class HubBinder : IInvocationBinder
{
private HubConnection _connection;
@ -282,7 +330,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
{
if (!_connection._pendingCalls.TryGetValue(invocationId, out InvocationRequest irq))
{
_connection._logger.LogError("Unsolicited response received for invocation '{0}'", invocationId);
_connection._logger.LogError("Unsolicited response received for invocation '{invocationId}'", invocationId);
return null;
}
return irq.ResultType;
@ -292,7 +340,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
{
if (!_connection._handlers.TryGetValue(methodName, out InvocationHandler handler))
{
_connection._logger.LogWarning("Failed to find handler for '{0}' method", methodName);
_connection._logger.LogWarning("Failed to find handler for '{target}' method", methodName);
return Type.EmptyTypes;
}
return handler.ParameterTypes;
@ -311,20 +359,51 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
}
private struct InvocationRequest
private class InvocationRequest : IDisposable
{
private readonly TaskCompletionSource<object> _completionSource = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly CancellationTokenRegistration _cancellationTokenRegistration;
private readonly ILogger _logger;
public Type ResultType { get; }
public CancellationToken CancellationToken { get; }
public CancellationTokenRegistration Registration { get; }
public TaskCompletionSource<object> Completion { get; }
public string InvocationId { get; }
public InvocationRequest(CancellationToken cancellationToken, Type resultType)
public Task<object> Completion => _completionSource.Task;
public InvocationRequest(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory)
{
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
Completion = tcs;
_logger = loggerFactory.CreateLogger<InvocationRequest>();
_cancellationTokenRegistration = cancellationToken.Register(() => _completionSource.TrySetCanceled());
InvocationId = invocationId;
CancellationToken = cancellationToken;
Registration = cancellationToken.Register(() => tcs.TrySetCanceled());
ResultType = resultType;
_logger.LogTrace("Invocation {invocationId} created", InvocationId);
}
public void Fail(Exception exception)
{
_logger.LogTrace("Invocation {invocationId} marked as failed", InvocationId);
_completionSource.TrySetException(exception);
}
public void Complete(object result)
{
_logger.LogTrace("Invocation {invocationId} marked as completed", InvocationId);
_completionSource.TrySetResult(result);
}
public void Dispose()
{
_logger.LogTrace("Invocation {invocationId} disposed", InvocationId);
// Just in case it hasn't already been completed
_completionSource.TrySetCanceled();
_cancellationTokenRegistration.Dispose();
}
}
}

View File

@ -0,0 +1,23 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.SignalR.Client
{
[Serializable]
public class HubException : Exception
{
public HubException()
{
}
public HubException(string message) : base(message)
{
}
public HubException(string message, Exception innerException) : base(message, innerException)
{
}
}
}

View File

@ -1,16 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR
{
public interface IInvocationAdapter
{
Task<InvocationMessage> ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken);
Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken);
}
}

View File

@ -3,7 +3,7 @@
using System;
namespace Microsoft.AspNetCore.SignalR
namespace Microsoft.AspNetCore.SignalR.Internal
{
public interface IInvocationBinder
{

View File

@ -0,0 +1,38 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public class CompletionMessage : HubMessage
{
public string Error { get; }
public object Result { get; }
public bool HasResult { get; }
public CompletionMessage(string invocationId, string error, object result, bool hasResult) : base(invocationId)
{
if (error != null && result != null)
{
throw new ArgumentException($"Expected either '{nameof(error)}' or '{nameof(result)}' to be provided, but not both");
}
Error = error;
Result = result;
HasResult = hasResult;
}
public override string ToString()
{
var errorStr = Error == null ? "<<null>>" : $"\"{Error}\"";
var resultField = HasResult ? $", {nameof(Result)}: {Result ?? "<<null>>"}" : string.Empty;
return $"Completion {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Error)}: {errorStr}{resultField} }}";
}
// Static factory methods. Don't want to use constructor overloading because it will break down
// if you need to send a payload statically-typed as a string. And because a static factory is clearer here
public static CompletionMessage WithError(string invocationId, string error) => new CompletionMessage(invocationId, error, result: null, hasResult: false);
public static CompletionMessage WithResult(string invocationId, object payload) => new CompletionMessage(invocationId, error: null, result: payload, hasResult: true);
}
}

View File

@ -0,0 +1,17 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public abstract class HubMessage
{
public string InvocationId { get; }
protected HubMessage(string invocationId)
{
InvocationId = invocationId;
}
}
}

View File

@ -0,0 +1,37 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.IO.Pipelines;
using System.IO.Pipelines.Text.Primitives;
using System.Text;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public static class HubProtocolWriteMessageExtensions
{
public static async ValueTask<byte[]> WriteToArrayAsync(this IHubProtocol protocol, HubMessage message)
{
using (var memoryStream = new MemoryStream())
{
var pipe = memoryStream.AsPipelineWriter();
// See https://github.com/dotnet/corefxlab/issues/1460, the TextEncoder is unimportant but required.
var output = new PipelineTextOutput(pipe, TextEncoder.Utf8);
// Encode the message
if (!protocol.TryWriteMessage(message, output))
{
throw new InvalidOperationException("Failed to write message to the output stream");
}
await output.FlushAsync();
// Create a message
return memoryStream.ToArray();
}
}
}
}

View File

@ -0,0 +1,18 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using Microsoft.AspNetCore.Sockets;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public interface IHubProtocol
{
MessageType MessageType { get; }
HubMessage ParseMessage(ReadOnlySpan<byte> input, IInvocationBinder binder);
bool TryWriteMessage(HubMessage message, IOutput output);
}
}

View File

@ -0,0 +1,44 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public class InvocationMessage : HubMessage
{
public string Target { get; }
public object[] Arguments { get; }
public bool NonBlocking { get; }
public InvocationMessage(string invocationId, bool nonBlocking, string target, params object[] arguments) : base(invocationId)
{
if (string.IsNullOrEmpty(invocationId))
{
throw new ArgumentNullException(nameof(invocationId));
}
if (string.IsNullOrEmpty(target))
{
throw new ArgumentNullException(nameof(target));
}
if (arguments == null)
{
throw new ArgumentNullException(nameof(arguments));
}
Target = target;
Arguments = arguments;
NonBlocking = nonBlocking;
}
public override string ToString()
{
return $"Invocation {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(NonBlocking)}: {NonBlocking}, {nameof(Target)}: \"{Target}\", {nameof(Arguments)}: [ {string.Join(", ", Arguments.Select(a => a?.ToString()))} ] }}";
}
}
}

View File

@ -0,0 +1,281 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.IO;
using Microsoft.AspNetCore.Sockets;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public class JsonHubProtocol : IHubProtocol
{
private const string ResultPropertyName = "result";
private const string InvocationIdPropertyName = "invocationId";
private const string TypePropertyName = "type";
private const string ErrorPropertyName = "error";
private const string TargetPropertyName = "target";
private const string NonBlockingPropertyName = "nonBlocking";
private const string ArgumentsPropertyName = "arguments";
private const int InvocationMessageType = 1;
private const int ResultMessageType = 2;
private const int CompletionMessageType = 3;
// ONLY to be used for application payloads (args, return values, etc.)
private JsonSerializer _payloadSerializer;
public MessageType MessageType => MessageType.Text;
/// <summary>
/// Creates an instance of the <see cref="JsonHubProtocol"/> using the specified <see cref="JsonSerializer"/>
/// to serialize application payloads (arguments, results, etc.). The serialization of the outer protocol can
/// NOT be changed using this serializer.
/// </summary>
/// <param name="payloadSerializer">The <see cref="JsonSerializer"/> to use to serialize application payloads (arguments, results, etc.).</param>
public JsonHubProtocol(JsonSerializer payloadSerializer)
{
if (payloadSerializer == null)
{
throw new ArgumentNullException(nameof(payloadSerializer));
}
_payloadSerializer = payloadSerializer;
}
public HubMessage ParseMessage(ReadOnlySpan<byte> input, IInvocationBinder binder)
{
// TODO: Need a span-native JSON parser!
using (var memoryStream = new MemoryStream(input.ToArray()))
{
return ParseMessage(memoryStream, binder);
}
}
public bool TryWriteMessage(HubMessage message, IOutput output)
{
// TODO: Need IOutput-compatible JSON serializer!
using (var memoryStream = new MemoryStream())
{
WriteMessage(message, memoryStream);
memoryStream.Flush();
return output.TryWrite(memoryStream.ToArray());
}
}
private HubMessage ParseMessage(Stream input, IInvocationBinder binder)
{
using (var reader = new JsonTextReader(new StreamReader(input)))
{
try
{
// PERF: Could probably use the JsonTextReader directly for better perf and fewer allocations
var token = JToken.ReadFrom(reader);
if (token == null)
{
return null;
}
if (token.Type != JTokenType.Object)
{
throw new FormatException($"Unexpected JSON Token Type '{token.Type}'. Expected a JSON Object.");
}
var json = (JObject)token;
// Determine the type of the message
var type = GetRequiredProperty<int>(json, TypePropertyName, JTokenType.Integer);
switch (type)
{
case InvocationMessageType:
return BindInvocationMessage(json, binder);
case ResultMessageType:
return BindResultMessage(json, binder);
case CompletionMessageType:
return BindCompletionMessage(json, binder);
default:
throw new FormatException($"Unknown message type: {type}");
}
}
catch (JsonReaderException jrex)
{
throw new FormatException("Error reading JSON.", jrex);
}
}
}
private void WriteMessage(HubMessage message, Stream stream)
{
using (var writer = new JsonTextWriter(new StreamWriter(stream)))
{
switch (message)
{
case InvocationMessage m:
WriteInvocationMessage(m, writer);
break;
case StreamItemMessage m:
WriteResultMessage(m, writer);
break;
case CompletionMessage m:
WriteCompletionMessage(m, writer);
break;
default:
throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}");
}
}
}
private void WriteCompletionMessage(CompletionMessage message, JsonTextWriter writer)
{
writer.WriteStartObject();
WriteHubMessageCommon(message, writer, CompletionMessageType);
if (!string.IsNullOrEmpty(message.Error))
{
writer.WritePropertyName(ErrorPropertyName);
writer.WriteValue(message.Error);
}
else if (message.HasResult)
{
writer.WritePropertyName(ResultPropertyName);
_payloadSerializer.Serialize(writer, message.Result);
}
writer.WriteEndObject();
}
private void WriteResultMessage(StreamItemMessage message, JsonTextWriter writer)
{
writer.WriteStartObject();
WriteHubMessageCommon(message, writer, ResultMessageType);
writer.WritePropertyName(ResultPropertyName);
_payloadSerializer.Serialize(writer, message.Item);
writer.WriteEndObject();
}
private void WriteInvocationMessage(InvocationMessage message, JsonTextWriter writer)
{
writer.WriteStartObject();
WriteHubMessageCommon(message, writer, InvocationMessageType);
writer.WritePropertyName(TargetPropertyName);
writer.WriteValue(message.Target);
if (message.NonBlocking)
{
writer.WritePropertyName(NonBlockingPropertyName);
writer.WriteValue(message.NonBlocking);
}
writer.WritePropertyName(ArgumentsPropertyName);
writer.WriteStartArray();
foreach (var argument in message.Arguments)
{
_payloadSerializer.Serialize(writer, argument);
}
writer.WriteEndArray();
writer.WriteEndObject();
}
private static void WriteHubMessageCommon(HubMessage message, JsonTextWriter writer, int type)
{
writer.WritePropertyName(InvocationIdPropertyName);
writer.WriteValue(message.InvocationId);
writer.WritePropertyName(TypePropertyName);
writer.WriteValue(type);
}
private InvocationMessage BindInvocationMessage(JObject json, IInvocationBinder binder)
{
var invocationId = GetRequiredProperty<string>(json, InvocationIdPropertyName, JTokenType.String);
var target = GetRequiredProperty<string>(json, TargetPropertyName, JTokenType.String);
var nonBlocking = GetOptionalProperty<bool>(json, NonBlockingPropertyName, JTokenType.Boolean);
var args = GetRequiredProperty<JArray>(json, ArgumentsPropertyName, JTokenType.Array);
var paramTypes = binder.GetParameterTypes(target);
var arguments = new object[args.Count];
if (paramTypes.Length != arguments.Length)
{
throw new FormatException($"Invocation provides {arguments.Length} argument(s) but target expects {paramTypes.Length}.");
}
for (var i = 0; i < paramTypes.Length; i++)
{
var paramType = paramTypes[i];
// TODO(anurse): We can add some DI magic here to allow users to provide their own serialization
// Related Bug: https://github.com/aspnet/SignalR/issues/261
arguments[i] = args[i].ToObject(paramType, _payloadSerializer);
}
return new InvocationMessage(invocationId, nonBlocking, target, arguments);
}
private StreamItemMessage BindResultMessage(JObject json, IInvocationBinder binder)
{
var invocationId = GetRequiredProperty<string>(json, InvocationIdPropertyName, JTokenType.String);
var result = GetRequiredProperty<JToken>(json, ResultPropertyName);
var returnType = binder.GetReturnType(invocationId);
return new StreamItemMessage(invocationId, result?.ToObject(returnType, _payloadSerializer));
}
private CompletionMessage BindCompletionMessage(JObject json, IInvocationBinder binder)
{
var invocationId = GetRequiredProperty<string>(json, InvocationIdPropertyName, JTokenType.String);
var error = GetOptionalProperty<string>(json, ErrorPropertyName, JTokenType.String);
var resultProp = json.Property(ResultPropertyName);
if (error != null && resultProp != null)
{
throw new FormatException("The 'error' and 'result' properties are mutually exclusive.");
}
if (resultProp == null)
{
return new CompletionMessage(invocationId, error, result: null, hasResult: false);
}
else
{
var returnType = binder.GetReturnType(invocationId);
var payload = resultProp.Value?.ToObject(returnType, _payloadSerializer);
return new CompletionMessage(invocationId, error, result: payload, hasResult: true);
}
}
private T GetOptionalProperty<T>(JObject json, string property, JTokenType expectedType = JTokenType.None, T defaultValue = default(T))
{
var prop = json[property];
if (prop == null)
{
return defaultValue;
}
return GetValue<T>(property, expectedType, prop);
}
private T GetRequiredProperty<T>(JObject json, string property, JTokenType expectedType = JTokenType.None)
{
var prop = json[property];
if (prop == null)
{
throw new FormatException($"Missing required property '{property}'.");
}
return GetValue<T>(property, expectedType, prop);
}
private static T GetValue<T>(string property, JTokenType expectedType, JToken prop)
{
if (expectedType != JTokenType.None && prop.Type != expectedType)
{
throw new FormatException($"Expected '{property}' to be of type {expectedType}.");
}
return prop.Value<T>();
}
}
}

View File

@ -0,0 +1,20 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public class StreamItemMessage : HubMessage
{
public object Item { get; }
public StreamItemMessage(string invocationId, object item) : base(invocationId)
{
Item = item;
}
public override string ToString()
{
return $"StreamItem {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Item)}: {Item ?? "<<null>>"} }}";
}
}
}

View File

@ -1,16 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR
{
public static class InvocationAdapterExtensions
{
public static Task<InvocationMessage> ReadMessageAsync(this IInvocationAdapter self, Stream stream, IInvocationBinder binder) => self.ReadMessageAsync(stream, binder, CancellationToken.None);
public static Task WriteMessageAsync(this IInvocationAdapter self, InvocationMessage message, Stream stream) => self.WriteMessageAsync(message, stream, CancellationToken.None);
}
}

View File

@ -1,19 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.SignalR
{
public class InvocationDescriptor : InvocationMessage
{
public string Method { get; set; }
public object[] Arguments { get; set; }
public override string ToString()
{
return $"{Id}: {Method}({(Arguments ?? new object[0]).Length})";
}
}
}

View File

@ -1,17 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR
{
public class InvocationResultDescriptor : InvocationMessage
{
public object Result { get; set; }
public string Error { get; set; }
}
}

View File

@ -1,89 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Internal;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
namespace Microsoft.AspNetCore.SignalR
{
public class JsonNetInvocationAdapter : IInvocationAdapter
{
private JsonSerializer _serializer = new JsonSerializer();
public JsonNetInvocationAdapter()
{
}
public Task<InvocationMessage> ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken)
{
var reader = new JsonTextReader(new StreamReader(stream));
// REVIEW: Task.Run()
return Task.Run<InvocationMessage>(() =>
{
cancellationToken.ThrowIfCancellationRequested();
var json = _serializer.Deserialize<JObject>(reader);
if (json == null)
{
return null;
}
// Determine the type of the message
if (json["Result"] != null)
{
// It's a result
return BindInvocationResultDescriptor(json, binder, cancellationToken);
}
else
{
return BindInvocationDescriptor(json, binder, cancellationToken);
}
}, cancellationToken);
}
public Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken)
{
var writer = new JsonTextWriter(new StreamWriter(stream));
_serializer.Serialize(writer, message);
writer.Flush();
return TaskCache.CompletedTask;
}
private InvocationDescriptor BindInvocationDescriptor(JObject json, IInvocationBinder binder, CancellationToken cancellationToken)
{
var invocation = new InvocationDescriptor
{
Id = json.Value<string>("Id"),
Method = json.Value<string>("Method"),
};
var paramTypes = binder.GetParameterTypes(invocation.Method);
invocation.Arguments = new object[paramTypes.Length];
var args = json.Value<JArray>("Arguments");
for (var i = 0; i < paramTypes.Length; i++)
{
var paramType = paramTypes[i];
invocation.Arguments[i] = args[i].ToObject(paramType, _serializer);
}
return invocation;
}
private InvocationResultDescriptor BindInvocationResultDescriptor(JObject json, IInvocationBinder binder, CancellationToken cancellationToken)
{
var id = json.Value<string>("Id");
var returnType = binder.GetReturnType(id);
var result = new InvocationResultDescriptor()
{
Id = id,
Result = returnType == null ? null : json["Result"].ToObject(returnType, _serializer),
Error = json.Value<string>("Error")
};
return result;
}
}
}

View File

@ -9,11 +9,20 @@
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<PackageTags>aspnetcore;signalr</PackageTags>
<EnableApiCheck>false</EnableApiCheck>
<RootNamespace>Microsoft.AspNetCore.SignalR</RootNamespace>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
<Compile Include="../Common/IOutputExtensions.cs" Link="IOutputExtensions.cs" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Newtonsoft.Json" Version="$(JsonNetVersion)" />
<PackageReference Include="System.IO.Pipelines.Text.Primitives" Version="$(CoreFxLabsVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common\Microsoft.AspNetCore.Sockets.Common.csproj" />
</ItemGroup>
</Project>

View File

@ -13,7 +13,6 @@
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.SignalR\Microsoft.AspNetCore.SignalR.csproj" />
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
<PackageReference Include="StackExchange.Redis.StrongName" Version="$(RedisVersion)" />
</ItemGroup>

View File

@ -5,53 +5,80 @@ using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Newtonsoft.Json;
using StackExchange.Redis;
namespace Microsoft.AspNetCore.SignalR.Redis
{
public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable
{
private const string RedisSubscriptionsMetadataName = "redis_subscriptions";
private readonly ConnectionList _connections = new ConnectionList();
// TODO: Investigate "memory leak" entries never get removed
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>();
private readonly InvocationAdapterRegistry _registry;
private readonly ConnectionMultiplexer _redisServerConnection;
private readonly ISubscriber _bus;
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
private readonly RedisOptions _options;
public RedisHubLifetimeManager(InvocationAdapterRegistry registry,
ILoggerFactory loggerFactory,
// This serializer is ONLY use to transmit the data through redis, it has no connection to the serializer used on each connection.
private readonly JsonSerializer _serializer = new JsonSerializer
{
// We need to serialize objects "full-fidelity", even if it is noisy, so we preserve the original types
TypeNameAssemblyFormatHandling = TypeNameAssemblyFormatHandling.Simple,
TypeNameHandling = TypeNameHandling.All,
Formatting = Formatting.None
};
private long _nextInvocationId = 0;
public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
IOptions<RedisOptions> options)
{
_loggerFactory = loggerFactory;
_registry = registry;
_logger = logger;
_options = options.Value;
var writer = new LoggerTextWriter(loggerFactory.CreateLogger<RedisHubLifetimeManager<THub>>());
var writer = new LoggerTextWriter(logger);
_logger.LogInformation("Connecting to redis endpoints: {endpoints}", string.Join(", ", options.Value.Options.EndPoints.Select(e => EndPointCollection.ToString(e))));
_redisServerConnection = _options.Connect(writer);
if (_redisServerConnection.IsConnected)
{
_logger.LogInformation("Connected to redis");
}
else
{
// TODO: We could support reconnecting, like old SignalR does.
throw new InvalidOperationException("Connection to redis failed.");
}
_bus = _redisServerConnection.GetSubscriber();
var previousBroadcastTask = TaskCache.CompletedTask;
var previousBroadcastTask = Task.CompletedTask;
_bus.Subscribe(typeof(THub).FullName, async (c, data) =>
var channelName = typeof(THub).FullName;
_logger.LogInformation("Subscribing to channel: {channel}", channelName);
_bus.Subscribe(channelName, async (c, data) =>
{
await previousBroadcastTask;
_logger.LogTrace("Received message from redis channel {channel}", channelName);
var message = DeserializeMessage(data);
// TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
var tasks = new List<Task>(_connections.Count);
foreach (var connection in _connections)
{
tasks.Add(WriteAsync(connection, data));
tasks.Add(WriteAsync(connection, message));
}
previousBroadcastTask = Task.WhenAll(tasks);
@ -60,80 +87,68 @@ namespace Microsoft.AspNetCore.SignalR.Redis
public override Task InvokeAllAsync(string methodName, object[] args)
{
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
return PublishAsync(typeof(THub).FullName, message);
}
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
{
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
return PublishAsync(typeof(THub).FullName + "." + connectionId, message);
}
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
{
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
return PublishAsync(typeof(THub).FullName + ".group." + groupName, message);
}
public override Task InvokeUserAsync(string userId, string methodName, object[] args)
{
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
return PublishAsync(typeof(THub).FullName + ".user." + userId, message);
}
private async Task PublishAsync(string channel, InvocationDescriptor message)
private async Task PublishAsync(string channel, HubMessage hubMessage)
{
// TODO: What format??
var invocationAdapter = _registry.GetInvocationAdapter("json");
// BAD
using (var ms = new MemoryStream())
byte[] payload;
using (var stream = new MemoryStream())
using (var writer = new JsonTextWriter(new StreamWriter(stream)))
{
await invocationAdapter.WriteMessageAsync(message, ms);
await _bus.PublishAsync(channel, ms.ToArray());
_serializer.Serialize(writer, hubMessage);
await writer.FlushAsync();
payload = stream.ToArray();
}
_logger.LogTrace("Publishing message to redis channel {channel}", channel);
await _bus.PublishAsync(channel, payload);
}
public override Task OnConnectedAsync(Connection connection)
{
var redisSubscriptions = connection.Metadata.GetOrAdd("redis_subscriptions", _ => new HashSet<string>());
var connectionTask = TaskCache.CompletedTask;
var userTask = TaskCache.CompletedTask;
var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet<string>());
var connectionTask = Task.CompletedTask;
var userTask = Task.CompletedTask;
_connections.Add(connection);
var connectionChannel = typeof(THub).FullName + "." + connection.ConnectionId;
redisSubscriptions.Add(connectionChannel);
var previousConnectionTask = TaskCache.CompletedTask;
var previousConnectionTask = Task.CompletedTask;
_logger.LogInformation("Subscribing to connection channel: {channel}", connectionChannel);
connectionTask = _bus.SubscribeAsync(connectionChannel, async (c, data) =>
{
await previousConnectionTask;
previousConnectionTask = WriteAsync(connection, data);
var message = DeserializeMessage(data);
previousConnectionTask = WriteAsync(connection, message);
});
@ -142,14 +157,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var userChannel = typeof(THub).FullName + ".user." + connection.User.Identity.Name;
redisSubscriptions.Add(userChannel);
var previousUserTask = TaskCache.CompletedTask;
var previousUserTask = Task.CompletedTask;
// TODO: Look at optimizing (looping over connections checking for Name)
userTask = _bus.SubscribeAsync(userChannel, async (c, data) =>
{
await previousUserTask;
previousUserTask = WriteAsync(connection, data);
var message = DeserializeMessage(data);
previousUserTask = WriteAsync(connection, message);
});
}
@ -162,16 +179,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var tasks = new List<Task>();
var redisSubscriptions = connection.Metadata.Get<HashSet<string>>("redis_subscriptions");
var redisSubscriptions = connection.Metadata.Get<HashSet<string>>(RedisSubscriptionsMetadataName);
if (redisSubscriptions != null)
{
foreach (var subscription in redisSubscriptions)
{
_logger.LogInformation("Unsubscribing from channel: {channel}", subscription);
tasks.Add(_bus.UnsubscribeAsync(subscription));
}
}
var groupNames = connection.Metadata.Get<HashSet<string>>("group");
var groupNames = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
if (groupNames != null)
{
@ -190,7 +208,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
var groupChannel = typeof(THub).FullName + ".group." + groupName;
var groupNames = connection.Metadata.GetOrAdd("group", _ => new HashSet<string>());
var groupNames = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet<string>());
lock (groupNames)
{
@ -210,8 +228,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return;
}
var previousTask = TaskCache.CompletedTask;
var previousTask = Task.CompletedTask;
_logger.LogInformation("Subscribing to group channel: {channel}", groupChannel);
await _bus.SubscribeAsync(groupChannel, async (c, data) =>
{
// Since this callback is async, we await the previous task then
@ -219,10 +238,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis
// want to do concurrent writes to the outgoing connections
await previousTask;
var message = DeserializeMessage(data);
var tasks = new List<Task>(group.Connections.Count);
foreach (var groupConnection in group.Connections)
{
tasks.Add(WriteAsync(groupConnection, data));
tasks.Add(WriteAsync(groupConnection, message));
}
previousTask = Task.WhenAll(tasks);
@ -244,7 +265,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return;
}
var groupNames = connection.Metadata.Get<HashSet<string>>("group");
var groupNames = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
if (groupNames != null)
{
lock (groupNames)
@ -260,6 +281,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
if (group.Connections.Count == 0)
{
_logger.LogInformation("Unsubscribing from group channel: {channel}", groupChannel);
await _bus.UnsubscribeAsync(groupChannel);
}
}
@ -275,9 +297,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
_redisServerConnection.Dispose();
}
private async Task WriteAsync(Connection connection, byte[] data)
private async Task WriteAsync(Connection connection, HubMessage hubMessage)
{
var message = new Message(data, MessageType.Text, endOfMessage: true);
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
var data = await protocol.WriteToArrayAsync(hubMessage);
var message = new Message(data, protocol.MessageType, endOfMessage: true);
while (await connection.Transport.Output.WaitToWriteAsync())
{
@ -288,6 +312,23 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
private string GetInvocationId()
{
var invocationId = Interlocked.Increment(ref _nextInvocationId);
return invocationId.ToString();
}
private HubMessage DeserializeMessage(RedisValue data)
{
HubMessage message;
using (var reader = new JsonTextReader(new StreamReader(new MemoryStream((byte[])data))))
{
message = (HubMessage)_serializer.Deserialize(reader);
}
return message;
}
private class LoggerTextWriter : TextWriter
{
private readonly ILogger _logger;

View File

@ -3,43 +3,37 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.SignalR
{
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
{
private long _nextInvocationId = 0;
private readonly ConnectionList _connections = new ConnectionList();
private readonly InvocationAdapterRegistry _registry;
public DefaultHubLifetimeManager(InvocationAdapterRegistry registry)
{
_registry = registry;
}
public override Task AddGroupAsync(Connection connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
var groups = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet<string>());
lock (groups)
{
groups.Add(groupName);
}
return TaskCache.CompletedTask;
return Task.CompletedTask;
}
public override Task RemoveGroupAsync(Connection connection, string groupName)
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
var groups = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
if (groups == null)
{
return TaskCache.CompletedTask;
return Task.CompletedTask;
}
lock (groups)
@ -47,7 +41,7 @@ namespace Microsoft.AspNetCore.SignalR
groups.Remove(groupName);
}
return TaskCache.CompletedTask;
return Task.CompletedTask;
}
public override Task InvokeAllAsync(string methodName, object[] args)
@ -58,11 +52,7 @@ namespace Microsoft.AspNetCore.SignalR
private Task InvokeAllWhere(string methodName, object[] args, Func<Connection, bool> include)
{
var tasks = new List<Task>(_connections.Count);
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
// TODO: serialize once per format by providing a different stream?
foreach (var connection in _connections)
@ -72,9 +62,7 @@ namespace Microsoft.AspNetCore.SignalR
continue;
}
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
tasks.Add(WriteAsync(connection, invocationAdapter, message));
tasks.Add(WriteAsync(connection, message));
}
return Task.WhenAll(tasks);
@ -84,22 +72,16 @@ namespace Microsoft.AspNetCore.SignalR
{
var connection = _connections[connectionId];
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
return WriteAsync(connection, invocationAdapter, message);
return WriteAsync(connection, message);
}
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
{
return InvokeAllWhere(methodName, args, connection =>
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
var groups = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
return groups?.Contains(groupName) == true;
});
}
@ -115,21 +97,20 @@ namespace Microsoft.AspNetCore.SignalR
public override Task OnConnectedAsync(Connection connection)
{
_connections.Add(connection);
return TaskCache.CompletedTask;
return Task.CompletedTask;
}
public override Task OnDisconnectedAsync(Connection connection)
{
_connections.Remove(connection);
return TaskCache.CompletedTask;
return Task.CompletedTask;
}
private static async Task WriteAsync(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor invocation)
private async Task WriteAsync(Connection connection, HubMessage hubMessage)
{
var stream = new MemoryStream();
await invocationAdapter.WriteMessageAsync(invocation, stream);
var message = new Message(stream.ToArray(), MessageType.Text, endOfMessage: true);
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
var payload = await protocol.WriteToArrayAsync(hubMessage);
var message = new Message(payload, protocol.MessageType, endOfMessage: true);
while (await connection.Transport.Output.WaitToWriteAsync())
{
@ -139,5 +120,11 @@ namespace Microsoft.AspNetCore.SignalR
}
}
}
private string GetInvocationId()
{
var invocationId = Interlocked.Increment(ref _nextInvocationId);
return invocationId.ToString();
}
}
}

View File

@ -3,7 +3,6 @@
using System;
using System.Threading.Tasks;
using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.SignalR
{
@ -62,12 +61,12 @@ namespace Microsoft.AspNetCore.SignalR
public virtual Task OnConnectedAsync()
{
return TaskCache.CompletedTask;
return Task.CompletedTask;
}
public virtual Task OnDisconnectedAsync(Exception exception)
{
return TaskCache.CompletedTask;
return Task.CompletedTask;
}
protected virtual void Dispose(bool disposing)

View File

@ -1,15 +1,11 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR
{
public abstract class InvocationMessage
public static class HubConnectionMetadataNames
{
public string Id { get; set; }
public static readonly string HubProtocol = nameof(HubProtocol);
public static readonly string Groups = nameof(Groups);
}
}

View File

@ -1,13 +1,14 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;
@ -19,12 +20,12 @@ namespace Microsoft.AspNetCore.SignalR
public class HubEndPoint<THub> : HubEndPoint<THub, IClientProxy> where THub : Hub<IClientProxy>
{
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
IHubProtocolResolver protocolResolver,
IHubContext<THub> hubContext,
InvocationAdapterRegistry registry,
IOptions<EndPointOptions<HubEndPoint<THub, IClientProxy>>> endPointOptions,
ILogger<HubEndPoint<THub>> logger,
IServiceScopeFactory serviceScopeFactory)
: base(lifetimeManager, hubContext, registry, endPointOptions, logger, serviceScopeFactory)
: base(lifetimeManager, protocolResolver, hubContext, endPointOptions, logger, serviceScopeFactory)
{
}
}
@ -36,19 +37,19 @@ namespace Microsoft.AspNetCore.SignalR
private readonly HubLifetimeManager<THub> _lifetimeManager;
private readonly IHubContext<THub, TClient> _hubContext;
private readonly ILogger<HubEndPoint<THub, TClient>> _logger;
private readonly InvocationAdapterRegistry _registry;
private readonly IServiceScopeFactory _serviceScopeFactory;
private readonly IHubProtocolResolver _protocolResolver;
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
IHubProtocolResolver protocolResolver,
IHubContext<THub, TClient> hubContext,
InvocationAdapterRegistry registry,
IOptions<EndPointOptions<HubEndPoint<THub, TClient>>> endPointOptions,
ILogger<HubEndPoint<THub, TClient>> logger,
IServiceScopeFactory serviceScopeFactory)
{
_protocolResolver = protocolResolver;
_lifetimeManager = lifetimeManager;
_hubContext = hubContext;
_registry = registry;
_logger = logger;
_serviceScopeFactory = serviceScopeFactory;
@ -59,6 +60,11 @@ namespace Microsoft.AspNetCore.SignalR
{
try
{
// Resolve the Hub Protocol for the connection and store it in metadata
// Other components, outside the Hub, may need to know what protocol is in use
// for a particular connection, so we store it here.
connection.Metadata[HubConnectionMetadataNames.HubProtocol] = _protocolResolver.GetProtocol(connection);
await _lifetimeManager.OnConnectedAsync(connection);
await RunHubAsync(connection);
}
@ -140,14 +146,13 @@ namespace Microsoft.AspNetCore.SignalR
private async Task DispatchMessagesAsync(Connection connection)
{
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
// We use these for error handling. Since we dispatch multiple hub invocations
// in parallel, we need a way to communicate failure back to the main processing loop. The
// cancellation token is used to stop reading from the channel, the tcs
// is used to get the exception so we can bubble it up the stack
var cts = new CancellationTokenSource();
var tcs = new TaskCompletionSource<object>();
var completion = new TaskCompletionSource<object>();
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
try
{
@ -155,100 +160,94 @@ namespace Microsoft.AspNetCore.SignalR
{
while (connection.Transport.Input.TryRead(out var incomingMessage))
{
InvocationDescriptor invocationDescriptor;
var inputStream = new MemoryStream(incomingMessage.Payload);
var hubMessage = protocol.ParseMessage(incomingMessage.Payload, this);
// TODO: Handle receiving InvocationResultDescriptor
invocationDescriptor = await invocationAdapter.ReadMessageAsync(inputStream, this) as InvocationDescriptor;
// Is there a better way of detecting that a connection was closed?
if (invocationDescriptor == null)
switch (hubMessage)
{
break;
}
case InvocationMessage invocationMessage:
if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("Received hub invocation: {invocation}", invocationMessage);
}
if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("Received hub invocation: {invocation}", invocationDescriptor);
}
// Don't wait on the result of execution, continue processing other
// incoming messages on this connection.
var ignore = ProcessInvocation(connection, protocol, invocationMessage, cts, completion);
break;
// Don't wait on the result of execution, continue processing other
// incoming messages on this connection.
var ignore = ProcessInvocation(connection, invocationAdapter, invocationDescriptor, cts, tcs);
// Other kind of message we weren't expecting
default:
_logger.LogError("Received unsupported message of type '{messageType}'", hubMessage.GetType().FullName);
throw new NotSupportedException($"Received unsupported message: {hubMessage}");
}
}
}
}
catch (OperationCanceledException)
{
// Await the task so the exception bubbles up to the caller
await tcs.Task;
await completion.Task;
}
}
private async Task ProcessInvocation(Connection connection,
IInvocationAdapter invocationAdapter,
InvocationDescriptor invocationDescriptor,
CancellationTokenSource cts,
TaskCompletionSource<object> tcs)
IHubProtocol protocol,
InvocationMessage invocationMessage,
CancellationTokenSource dispatcherCancellation,
TaskCompletionSource<object> dispatcherCompletion)
{
try
{
// If an unexpected exception occurs then we want to kill the entire connection
// by ending the processing loop
await Execute(connection, invocationAdapter, invocationDescriptor);
await Execute(connection, protocol, invocationMessage);
}
catch (Exception ex)
{
// Set the exception on the task completion source
tcs.TrySetException(ex);
dispatcherCompletion.TrySetException(ex);
// Cancel reading operation
cts.Cancel();
dispatcherCancellation.Cancel();
}
}
private async Task Execute(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor invocationDescriptor)
private async Task Execute(Connection connection, IHubProtocol protocol, InvocationMessage invocationMessage)
{
InvocationResultDescriptor result;
HubMethodDescriptor descriptor;
if (_methods.TryGetValue(invocationDescriptor.Method, out descriptor))
if (!_methods.TryGetValue(invocationMessage.Target, out descriptor))
{
result = await Invoke(descriptor, connection, invocationDescriptor);
// Send an error to the client. Then let the normal completion process occur
_logger.LogError("Unknown hub method '{method}'", invocationMessage.Target);
await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, $"Unknown hub method '{invocationMessage.Target}'"));
}
else
{
// If there's no method then return a failed response for this request
result = new InvocationResultDescriptor
{
Id = invocationDescriptor.Id,
Error = $"Unknown hub method '{invocationDescriptor.Method}'"
};
_logger.LogError("Unknown hub method '{method}'", invocationDescriptor.Method);
}
// TODO: Pool memory
var outStream = new MemoryStream();
await invocationAdapter.WriteMessageAsync(result, outStream);
var outMessage = new Message(outStream.ToArray(), MessageType.Text, endOfMessage: true);
while (await connection.Transport.Output.WaitToWriteAsync())
{
if (connection.Transport.Output.TryWrite(outMessage))
{
break;
}
var result = await Invoke(descriptor, connection, invocationMessage);
await SendMessageAsync(connection, protocol, result);
}
}
private async Task<InvocationResultDescriptor> Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationDescriptor invocationDescriptor)
private async Task SendMessageAsync(Connection connection, IHubProtocol protocol, HubMessage hubMessage)
{
var invocationResult = new InvocationResultDescriptor
{
Id = invocationDescriptor.Id
};
var payload = await protocol.WriteToArrayAsync(hubMessage);
var message = new Message(payload, protocol.MessageType, endOfMessage: true);
while (await connection.Transport.Output.WaitToWriteAsync())
{
if (connection.Transport.Output.TryWrite(message))
{
return;
}
}
// Output is closed. Cancel this invocation completely
_logger.LogWarning("Outbound channel was closed while trying to write hub message");
throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
}
private async Task<CompletionMessage> Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationMessage invocationMessage)
{
var methodExecutor = descriptor.MethodExecutor;
using (var scope = _serviceScopeFactory.CreateScope())
@ -265,37 +264,35 @@ namespace Microsoft.AspNetCore.SignalR
{
if (methodExecutor.MethodReturnType == typeof(Task))
{
await (Task)methodExecutor.Execute(hub, invocationDescriptor.Arguments);
await (Task)methodExecutor.Execute(hub, invocationMessage.Arguments);
}
else
{
result = await methodExecutor.ExecuteAsync(hub, invocationDescriptor.Arguments);
result = await methodExecutor.ExecuteAsync(hub, invocationMessage.Arguments);
}
}
else
{
result = methodExecutor.Execute(hub, invocationDescriptor.Arguments);
result = methodExecutor.Execute(hub, invocationMessage.Arguments);
}
invocationResult.Result = result;
return CompletionMessage.WithResult(invocationMessage.InvocationId, result);
}
catch (TargetInvocationException ex)
{
_logger.LogError(0, ex, "Failed to invoke hub method");
invocationResult.Error = ex.InnerException.Message;
return CompletionMessage.WithError(invocationMessage.InvocationId, ex.InnerException.Message);
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Failed to invoke hub method");
invocationResult.Error = ex.Message;
return CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message);
}
finally
{
hubActivator.Release(hub);
}
}
return invocationResult;
}
private void InitializeHub(THub hub, Connection connection)

View File

@ -0,0 +1,18 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Internal
{
public class DefaultHubProtocolResolver : IHubProtocolResolver
{
public IHubProtocol GetProtocol(Connection connection)
{
// TODO: Allow customization of this serializer!
return new JsonHubProtocol(new JsonSerializer());
}
}
}

View File

@ -0,0 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
namespace Microsoft.AspNetCore.SignalR.Internal
{
public interface IHubProtocolResolver
{
IHubProtocol GetProtocol(Connection connection);
}
}

View File

@ -1,32 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.SignalR
{
public class InvocationAdapterRegistry
{
private readonly IServiceProvider _serviceProvider;
private readonly SignalROptions _options;
public InvocationAdapterRegistry(IOptions<SignalROptions> options, IServiceProvider serviceProvider)
{
_options = options.Value;
_serviceProvider = serviceProvider;
}
public IInvocationAdapter GetInvocationAdapter(string format)
{
Type type;
if (_options._invocationMappings.TryGetValue(format, out type))
{
return _serviceProvider.GetRequiredService(type) as IInvocationAdapter;
}
return null;
}
}
}

View File

@ -14,7 +14,6 @@
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets\Microsoft.AspNetCore.Sockets.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
<PackageReference Include="Microsoft.Extensions.ClosedGenericMatcher.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
<PackageReference Include="Microsoft.Extensions.ObjectMethodExecutor.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
<PackageReference Include="Newtonsoft.Json" Version="$(JsonNetVersion)" />

View File

@ -1,9 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Options;
using Microsoft.AspNetCore.SignalR.Internal;
namespace Microsoft.Extensions.DependencyInjection
{
@ -13,26 +12,13 @@ namespace Microsoft.Extensions.DependencyInjection
{
services.AddSockets();
services.AddSingleton(typeof(HubLifetimeManager<>), typeof(DefaultHubLifetimeManager<>));
services.AddSingleton(typeof(IHubProtocolResolver), typeof(DefaultHubProtocolResolver));
services.AddSingleton(typeof(IHubContext<>), typeof(HubContext<>));
services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>));
services.AddSingleton<IConfigureOptions<SignalROptions>, SignalROptionsSetup>();
services.AddSingleton<JsonNetInvocationAdapter>();
services.AddSingleton<InvocationAdapterRegistry>();
services.AddScoped(typeof(IHubActivator<,>), typeof(DefaultHubActivator<,>));
services.AddRouting();
return new SignalRBuilder(services);
}
public static ISignalRBuilder AddSignalR(this IServiceCollection services, Action<SignalROptions> setupAction)
{
return services.AddSignalR().AddSignalROptions(setupAction);
}
public static ISignalRBuilder AddSignalROptions(this ISignalRBuilder builder, Action<SignalROptions> setupAction)
{
builder.Services.Configure(setupAction);
return builder;
}
}
}

View File

@ -1,20 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR
{
public class SignalROptions
{
internal readonly Dictionary<string, Type> _invocationMappings = new Dictionary<string, Type>();
public void RegisterInvocationAdapter<TInvocationAdapter>(string format) where TInvocationAdapter : IInvocationAdapter
{
_invocationMappings[format] = typeof(TInvocationAdapter);
}
}
}

View File

@ -1,17 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Options;
namespace Microsoft.Extensions.DependencyInjection
{
public class SignalROptionsSetup : IConfigureOptions<SignalROptions>
{
public void Configure(SignalROptions options)
{
options.RegisterInvocationAdapter<JsonNetInvocationAdapter>("json");
}
}
}

View File

@ -10,7 +10,6 @@ using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
@ -58,7 +57,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
return t;
}).Unwrap();
return TaskCache.CompletedTask;
return Task.CompletedTask;
}
public async Task StopAsync()

View File

@ -17,7 +17,6 @@
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
<PackageReference Include="System.Text.Formatting" Version="$(CoreFxLabsVersion)" />
<PackageReference Include="System.IO.Pipelines" Version="$(CoreFxLabsVersion)" />
<PackageReference Include="System.IO.Pipelines.Text.Primitives" Version="$(CoreFxLabsVersion)" />

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
@ -8,7 +8,6 @@ using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
@ -61,7 +60,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
return t;
}).Unwrap();
return TaskCache.CompletedTask;
return Task.CompletedTask;
}
private async Task OpenConnection(IChannelConnection<SendMessage, Message> application, Uri url, CancellationToken cancellationToken)

View File

@ -12,6 +12,10 @@
<EnableApiCheck>false</EnableApiCheck>
</PropertyGroup>
<ItemGroup>
<Compile Include="../Common/IOutputExtensions.cs" Link="IOutputExtensions.cs" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="System.Binary.Base64" Version="$(CoreFxLabsVersion)" />
<PackageReference Include="System.IO.Pipelines" Version="$(CoreFxLabsVersion)" />

View File

@ -0,0 +1,12 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.Sockets
{
public static class ConnectionMetadataNames
{
public static readonly string Format = nameof(Format);
public static readonly string Transport = nameof(Transport);
public static readonly string HttpContext = nameof(HttpContext);
}
}

View File

@ -6,7 +6,7 @@ using Microsoft.AspNetCore.Sockets;
namespace Microsoft.Extensions.DependencyInjection
{
public static class EndpointDependencyInjectionExtensions
public static class EndPointDependencyInjectionExtensions
{
public static IServiceCollection AddEndPoint<TEndPoint>(this IServiceCollection services) where TEndPoint : EndPoint
{

View File

@ -167,7 +167,7 @@ namespace Microsoft.AspNetCore.Sockets
{
_logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId);
state.Connection.Metadata["transport"] = TransportType.LongPolling;
state.Connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling;
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
}
@ -232,13 +232,10 @@ namespace Microsoft.AspNetCore.Sockets
private ConnectionState CreateConnection(HttpContext context)
{
var state = _manager.CreateConnection();
var format = (string)context.Request.Query[ConnectionMetadataNames.Format];
state.Connection.User = context.User;
// TODO: this is wrong. + how does the user add their own metadata based on HttpContext
var formatType = (string)context.Request.Query["formatType"];
state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType;
state.Connection.Metadata[typeof(HttpContext)] = context;
state.Connection.Metadata[ConnectionMetadataNames.HttpContext] = context;
state.Connection.Metadata[ConnectionMetadataNames.Format] = string.IsNullOrEmpty(format) ? "json" : format;
return state;
}
@ -355,7 +352,7 @@ namespace Microsoft.AspNetCore.Sockets
var messages = ParseSendBatch(ref reader, messageFormat);
// REVIEW: Do we want to return a specific status code here if the connection has ended?
_logger.LogDebug("Received batch of {0} message(s) in '/send'", messages.Count);
_logger.LogDebug("Received batch of {count} message(s) in '/send'", messages.Count);
foreach (var message in messages)
{
while (!state.Application.Output.TryWrite(message))
@ -379,11 +376,11 @@ namespace Microsoft.AspNetCore.Sockets
connectionState.Connection.User = context.User;
var transport = connectionState.Connection.Metadata.Get<TransportType?>("transport");
var transport = connectionState.Connection.Metadata.Get<TransportType?>(ConnectionMetadataNames.Transport);
if (transport == null)
{
connectionState.Connection.Metadata["transport"] = transportType;
connectionState.Connection.Metadata[ConnectionMetadataNames.Transport] = transportType;
}
else if (transport != transportType)
{

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
@ -38,7 +38,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
public async Task DisposeAsync()
{
Task disposeTask = TaskCache.CompletedTask;
Task disposeTask = Task.CompletedTask;
try
{
@ -69,8 +69,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal
Connection.Dispose();
Application.Dispose();
var applicationTask = ApplicationTask ?? TaskCache.CompletedTask;
var transportTask = TransportTask ?? TaskCache.CompletedTask;
var applicationTask = ApplicationTask ?? Task.CompletedTask;
var transportTask = TransportTask ?? Task.CompletedTask;
disposeTask = WaitOnTasks(applicationTask, transportTask);
}

View File

@ -18,7 +18,6 @@
<PackageReference Include="Microsoft.AspNetCore.Hosting.Abstractions" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Routing" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.SecurityHelper.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
<PackageReference Include="System.Reflection.TypeExtensions" Version="$(CoreFxVersion)" />
<PackageReference Include="System.Security.Claims" Version="$(CoreFxVersion)" />
<PackageReference Include="System.Threading.Tasks.Channels" Version="$(CoreFxLabsVersion)" />

View File

@ -134,7 +134,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
// Is this a frame we care about?
if (!frame.Opcode.IsMessage())
{
return TaskCache.CompletedTask;
return Task.CompletedTask;
}
LogFrame("Receiving", frame);

View File

@ -4,7 +4,6 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Internal;
namespace Microsoft.Extensions.WebSockets.Internal
{
@ -135,7 +134,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
connection.ExecuteAsync((frame, _) =>
{
messageHandler(frame);
return TaskCache.CompletedTask;
return Task.CompletedTask;
}, null);
/// <summary>
@ -149,7 +148,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
connection.ExecuteAsync((frame, s) =>
{
messageHandler(frame, s);
return TaskCache.CompletedTask;
return Task.CompletedTask;
}, state);
/// <summary>

View File

@ -17,7 +17,6 @@
<PackageReference Include="System.ValueTuple" Version="$(CoreFxVersion)" />
<PackageReference Include="System.IO.Pipelines" Version="$(CoreFxLabsVersion)" />
<PackageReference Include="System.IO.Pipelines.Text.Primitives" Version="$(CoreFxLabsVersion)" />
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
</ItemGroup>
</Project>

View File

@ -1,7 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Tests.Common
@ -10,36 +11,48 @@ namespace Microsoft.AspNetCore.SignalR.Tests.Common
{
private const int DefaultTimeout = 5000;
public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout)
public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
{
return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds));
return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber);
}
public static async Task OrTimeout(this Task task, TimeSpan timeout)
public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
{
var completed = await Task.WhenAny(task, Task.Delay(timeout));
if (completed != task)
{
throw new TimeoutException();
throw new TimeoutException(GetMessage(memberName, filePath, lineNumber));
}
await task;
}
public static Task<T> OrTimeout<T>(this Task<T> task, int milliseconds = DefaultTimeout)
public static Task<T> OrTimeout<T>(this Task<T> task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
{
return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds));
return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber);
}
public static async Task<T> OrTimeout<T>(this Task<T> task, TimeSpan timeout)
public static async Task<T> OrTimeout<T>(this Task<T> task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
{
var completed = await Task.WhenAny(task, Task.Delay(timeout));
if (completed != task)
{
throw new TimeoutException();
throw new TimeoutException(GetMessage(memberName, filePath, lineNumber));
}
return await task;
}
private static string GetMessage(string memberName, string filePath, int? lineNumber)
{
if (!string.IsNullOrEmpty(memberName))
{
return $"Operation in {memberName} timed out at {filePath}:{lineNumber}";
}
else
{
return "Operation timed out";
}
}
}
}

View File

@ -98,7 +98,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
using (var httpClient = _testServer.CreateClient())
{
var connection = new HubConnection(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), loggerFactory);
var connection = new HubConnection(new Uri("http://test/hubs"), loggerFactory);
try
{
await connection.StartAsync(TransportType.LongPolling, httpClient);
@ -133,13 +133,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
tcs.TrySetResult((string)a[0]);
});
await connection.Invoke<Task>("CallEcho", originalMessage);
await connection.Invoke<Task>("CallEcho", originalMessage).OrTimeout();
Assert.Equal(originalMessage, await tcs.Task.OrTimeout());
}
finally
{
await connection.DisposeAsync();
await connection.DisposeAsync().OrTimeout();
}
}
}

View File

@ -0,0 +1,156 @@
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
// This includes tests that verify HubConnection conforms to the Hub Protocol, without setting up a full server (even TestServer).
// We can also have more control over the messages we send to HubConnection in order to ensure that protocol errors and other quirks
// don't cause problems.
public class HubConnectionProtocolTests
{
[Fact]
public async Task InvokeSendsAnInvocationMessage()
{
var connection = new TestConnection();
var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
try
{
await hubConnection.StartAsync();
var invokeTask = hubConnection.Invoke("Foo", typeof(void));
var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout();
Assert.Equal("{\"invocationId\":\"1\",\"type\":1,\"target\":\"Foo\",\"arguments\":[]}", invokeMessage);
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task InvokeCompletedWhenCompletionMessageReceived()
{
var connection = new TestConnection();
var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
try
{
await hubConnection.StartAsync();
var invokeTask = hubConnection.Invoke("Foo", typeof(void));
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout();
await invokeTask.OrTimeout();
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task InvokeYieldsResultWhenCompletionMessageReceived()
{
var connection = new TestConnection();
var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
try
{
await hubConnection.StartAsync();
var invokeTask = hubConnection.Invoke<int>("Foo");
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = 42 }).OrTimeout();
Assert.Equal(42, await invokeTask.OrTimeout());
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task InvokeFailsWithExceptionWhenCompletionWithErrorReceived()
{
var connection = new TestConnection();
var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
try
{
await hubConnection.StartAsync();
var invokeTask = hubConnection.Invoke<int>("Foo");
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout();
var ex = await Assert.ThrowsAsync<HubException>(() => invokeTask).OrTimeout();
Assert.Equal("An error occurred", ex.Message);
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
// This will fail (intentionally) when we support streaming!
public async Task InvokeFailsWithErrorWhenStreamingItemReceived()
{
var connection = new TestConnection();
var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
try
{
await hubConnection.StartAsync();
var invokeTask = hubConnection.Invoke<int>("Foo");
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, result = 42 }).OrTimeout();
var ex = await Assert.ThrowsAsync<NotSupportedException>(() => invokeTask).OrTimeout();
Assert.Equal("Streaming method results are not supported", ex.Message);
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task HandlerRegisteredWithOnIsFiredWhenInvocationReceived()
{
var connection = new TestConnection();
var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
var handlerCalled = new TaskCompletionSource<object[]>();
try
{
await hubConnection.StartAsync();
hubConnection.On("Foo", new[] { typeof(int), typeof(string), typeof(float) }, (a) => handlerCalled.TrySetResult(a));
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 1, target = "Foo", arguments = new object[] { 1, "Foo", 2.0f } }).OrTimeout();
var results = await handlerCalled.Task.OrTimeout();
Assert.Equal(3, results.Length);
Assert.Equal(1, results[0]);
Assert.Equal("Foo", results[1]);
Assert.Equal(2.0f, results[2]);
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
}
}

View File

@ -2,12 +2,13 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Buffers;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Client.Tests;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
@ -25,14 +26,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
public void CannotCreateHubConnectionWithNullUrl()
{
var exception = Assert.Throws<ArgumentNullException>(
() => new HubConnection((Uri)null, Mock.Of<IInvocationAdapter>(), Mock.Of<ILoggerFactory>()));
() => new HubConnection((Uri)null, Mock.Of<ILoggerFactory>()));
Assert.Equal("url", exception.ParamName);
}
[Fact]
public async Task CanDisposeNotStartedHubConnection()
{
await new HubConnection(new Uri("http://fakeuri.org"), Mock.Of<IInvocationAdapter>(), new LoggerFactory())
await new HubConnection(new Uri("http://fakeuri.org"), new LoggerFactory())
.DisposeAsync();
}
@ -50,7 +51,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), Mock.Of<IInvocationAdapter>(), new LoggerFactory());
var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), new LoggerFactory());
try
{
@ -81,7 +82,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), Mock.Of<IInvocationAdapter>(), new LoggerFactory());
var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), new LoggerFactory());
await hubConnection.StartAsync(TransportType.LongPolling, httpClient);
await hubConnection.DisposeAsync();
@ -110,7 +111,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var exception =
await Assert.ThrowsAsync<InvalidOperationException>(async () => await hubConnection.Invoke<int>("test"));
Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message);
Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message);
}
[Fact]
@ -119,12 +120,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var mockConnection = new Mock<IConnection>();
var exception = new InvalidOperationException();
var mockInvocationAdapter = new Mock<IInvocationAdapter>();
mockInvocationAdapter
.Setup(a => a.WriteMessageAsync(It.IsAny<InvocationMessage>(), It.IsAny<Stream>(), It.IsAny<CancellationToken>()))
.Returns(Task.FromException(exception));
var hubConnection = new HubConnection(mockConnection.Object, mockInvocationAdapter.Object, null);
var mockProtocol = MockHubProtocol.Throw(exception);
var hubConnection = new HubConnection(mockConnection.Object, mockProtocol, null);
await hubConnection.StartAsync(TransportType.All, httpClient: null);
var actualException =
@ -146,7 +143,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), Mock.Of<IInvocationAdapter>(), new LoggerFactory());
var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), new LoggerFactory());
try
{
var connectedEventRaisedTcs = new TaskCompletionSource<object>();
@ -177,7 +174,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), Mock.Of<IInvocationAdapter>(), new LoggerFactory());
var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), new LoggerFactory());
var closedEventTcs = new TaskCompletionSource<Exception>();
hubConnection.Closed += e => closedEventTcs.SetResult(e);
@ -197,7 +194,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
.Callback(() => mockConnection.Raise(c => c.Closed += null, (Exception)null))
.Returns(Task.FromResult<object>(null));
var hubConnection = new HubConnection(mockConnection.Object, Mock.Of<IInvocationAdapter>(), new LoggerFactory());
var hubConnection = new HubConnection(mockConnection.Object, new LoggerFactory());
await hubConnection.StartAsync(new TestTransportFactory(Mock.Of<ITransport>()), httpClient: null);
await hubConnection.DisposeAsync();
@ -216,7 +213,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
.Callback(() => mockConnection.Raise(c => c.Closed += null, (Exception)null))
.Returns(Task.FromResult<object>(null));
var hubConnection = new HubConnection(mockConnection.Object, Mock.Of<IInvocationAdapter>(), new LoggerFactory());
var hubConnection = new HubConnection(mockConnection.Object, new LoggerFactory());
await hubConnection.StartAsync(new TestTransportFactory(Mock.Of<ITransport>()), httpClient: null);
var invokeTask = hubConnection.Invoke("testMethod", typeof(int));
@ -235,7 +232,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
.Callback(() => mockConnection.Raise(c => c.Closed += null, exception))
.Returns(Task.FromResult<object>(null));
var hubConnection = new HubConnection(mockConnection.Object, Mock.Of<IInvocationAdapter>(), new LoggerFactory());
var hubConnection = new HubConnection(mockConnection.Object, new LoggerFactory());
await hubConnection.StartAsync(new TestTransportFactory(Mock.Of<ITransport>()), httpClient: null);
var invokeTask = hubConnection.Invoke("testMethod", typeof(int));
@ -250,23 +247,69 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var mockConnection = new Mock<IConnection>();
var invocationDescriptor = new InvocationDescriptor
{
Method = "NonExistingMethod123",
Arguments = new object[] { true, "arg2", 123 },
Id = Guid.NewGuid().ToString()
};
var invocation = new InvocationMessage(Guid.NewGuid().ToString(), nonBlocking: true, target: "NonExistingMethod123", arguments: new object[] { true, "arg2", 123 });
var mockInvocationAdapter = new Mock<IInvocationAdapter>();
mockInvocationAdapter
.Setup(a => a.ReadMessageAsync(It.IsAny<Stream>(), It.IsAny<IInvocationBinder>(), It.IsAny<CancellationToken>()))
.Returns(Task.FromResult((InvocationMessage)invocationDescriptor));
var mockProtocol = MockHubProtocol.ReturnOnParse(invocation);
var hubConnection = new HubConnection(mockConnection.Object, mockInvocationAdapter.Object, null);
var hubConnection = new HubConnection(mockConnection.Object, mockProtocol, null);
await hubConnection.StartAsync(new TestTransportFactory(Mock.Of<ITransport>()), httpClient: null);
mockConnection.Raise(c => c.Received += null, new object[] { new byte[] { }, MessageType.Text });
mockInvocationAdapter.Verify(a => a.ReadMessageAsync(It.IsAny<Stream>(), It.IsAny<IInvocationBinder>(), It.IsAny<CancellationToken>()), Times.Once());
Assert.Equal(1, mockProtocol.ParseCalls);
}
// Moq really doesn't handle out parameters well, so to make these tests work I added a manual mock -anurse
private class MockHubProtocol : IHubProtocol
{
private HubMessage _parsed;
private Exception _error;
public int ParseCalls { get; private set; } = 0;
public int WriteCalls { get; private set; } = 0;
public MessageType MessageType => MessageType.Text;
public static MockHubProtocol ReturnOnParse(HubMessage parsed)
{
return new MockHubProtocol
{
_parsed = parsed
};
}
public static MockHubProtocol Throw(Exception error)
{
return new MockHubProtocol
{
_error = error
};
}
public HubMessage ParseMessage(ReadOnlySpan<byte> input, IInvocationBinder binder)
{
ParseCalls += 1;
if (_error != null)
{
throw _error;
}
if (_parsed != null)
{
return _parsed;
}
throw new InvalidOperationException("No Parsed Message provided");
}
public bool TryWriteMessage(HubMessage message, IOutput output)
{
WriteCalls += 1;
if (_error != null)
{
throw _error;
}
return true;
}
}
}
}

View File

@ -0,0 +1,116 @@
using System;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
internal class TestConnection : IConnection
{
private TaskCompletionSource<object> _started = new TaskCompletionSource<object>();
private TaskCompletionSource<object> _disposed = new TaskCompletionSource<object>();
private Channel<Message> _sentMessages = Channel.CreateUnbounded<Message>();
private Channel<Message> _receivedMessages = Channel.CreateUnbounded<Message>();
private CancellationTokenSource _receiveShutdownToken = new CancellationTokenSource();
private Task _receiveLoop;
public event Action Connected;
public event Action<byte[], MessageType> Received;
public event Action<Exception> Closed;
public Task Started => _started.Task;
public Task Disposed => _disposed.Task;
public ReadableChannel<Message> SentMessages => _sentMessages.In;
public WritableChannel<Message> ReceivedMessages => _receivedMessages.Out;
public TestConnection()
{
_receiveLoop = ReceiveLoopAsync(_receiveShutdownToken.Token);
}
public Task DisposeAsync()
{
_disposed.TrySetResult(null);
_receiveShutdownToken.Cancel();
return _receiveLoop;
}
public async Task SendAsync(byte[] data, MessageType type, CancellationToken cancellationToken)
{
if(!_started.Task.IsCompleted)
{
throw new InvalidOperationException("Connection must be started before SendAsync can be called");
}
var message = new Message(data, type, endOfMessage: true);
while (await _sentMessages.Out.WaitToWriteAsync(cancellationToken))
{
if (_sentMessages.Out.TryWrite(message))
{
return;
}
}
throw new ObjectDisposedException("Unable to send message, underlying channel was closed");
}
public Task StartAsync(ITransportFactory transportFactory, HttpClient httpClient)
{
_started.TrySetResult(null);
Connected?.Invoke();
return Task.CompletedTask;
}
public async Task<string> ReadSentTextMessageAsync()
{
var message = await SentMessages.ReadAsync();
if (message.Type != MessageType.Text)
{
throw new InvalidOperationException($"Unexpected message of type: {message.Type}");
}
return Encoding.UTF8.GetString(message.Payload);
}
public Task ReceiveJsonMessage(object jsonObject)
{
var json = JsonConvert.SerializeObject(jsonObject, Formatting.None);
var bytes = Encoding.UTF8.GetBytes(json);
var message = new Message(bytes, MessageType.Text);
return _receivedMessages.Out.WriteAsync(message);
}
private async Task ReceiveLoopAsync(CancellationToken token)
{
try
{
while (!token.IsCancellationRequested)
{
while (await _receivedMessages.In.WaitToReadAsync(token))
{
while (_receivedMessages.In.TryRead(out var message))
{
Received?.Invoke(message.Payload, message.Type);
}
}
}
Closed?.Invoke(null);
}
catch (OperationCanceledException)
{
// Do nothing, we were just asked to shut down.
Closed?.Invoke(null);
}
catch (Exception ex)
{
Closed?.Invoke(ex);
}
}
}
}

View File

@ -0,0 +1,257 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
{
public class JsonHubProtocolTests
{
public static IEnumerable<object[]> ProtocolTestData => new[]
{
new object[] { new InvocationMessage("123", true, "Target", 1, "Foo", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"nonBlocking\":true,\"arguments\":[1,\"Foo\",2.0]}" },
new object[] { new InvocationMessage("123", false, "Target", 1, "Foo", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[1,\"Foo\",2.0]}" },
new object[] { new InvocationMessage("123", false, "Target", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[true]}" },
new object[] { new InvocationMessage("123", false, "Target", new object[] { null }), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[null]}" },
new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}]}" },
new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}]}" },
new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}]}" },
new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}]}" },
new object[] { new StreamItemMessage("123", 1), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":1}" },
new object[] { new StreamItemMessage("123", "Foo"), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":\"Foo\"}" },
new object[] { new StreamItemMessage("123", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":2.0}" },
new object[] { new StreamItemMessage("123", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":true}" },
new object[] { new StreamItemMessage("123", null), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":null}" },
new object[] { new StreamItemMessage("123", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}}" },
new object[] { new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}}" },
new object[] { new StreamItemMessage("123", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}}" },
new object[] { new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}}" },
new object[] { CompletionMessage.WithResult("123", 1), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":1}" },
new object[] { CompletionMessage.WithResult("123", "Foo"), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":\"Foo\"}" },
new object[] { CompletionMessage.WithResult("123", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":2.0}" },
new object[] { CompletionMessage.WithResult("123", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":true}" },
new object[] { CompletionMessage.WithResult("123", null), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":null}" },
new object[] { CompletionMessage.WithError("123", "Whoops!"), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"error\":\"Whoops!\"}" },
new object[] { CompletionMessage.WithResult("123", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}}" },
new object[] { CompletionMessage.WithResult("123", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}}" },
new object[] { CompletionMessage.WithResult("123", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}}" },
new object[] { CompletionMessage.WithResult("123", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}}" },
};
[Theory]
[MemberData(nameof(ProtocolTestData))]
public async Task WriteMessage(HubMessage message, bool camelCase, NullValueHandling nullValueHandling, string expectedOutput)
{
var jsonSerializer = new JsonSerializer
{
NullValueHandling = nullValueHandling,
ContractResolver = camelCase ? new CamelCasePropertyNamesContractResolver() : new DefaultContractResolver()
};
var protocol = new JsonHubProtocol(jsonSerializer);
var encoded = await protocol.WriteToArrayAsync(message);
var json = Encoding.UTF8.GetString(encoded);
Assert.Equal(expectedOutput, json);
}
[Theory]
[MemberData(nameof(ProtocolTestData))]
public void ParseMessage(HubMessage expectedMessage, bool camelCase, NullValueHandling nullValueHandling, string input)
{
var jsonSerializer = new JsonSerializer
{
NullValueHandling = nullValueHandling,
ContractResolver = camelCase ? new CamelCasePropertyNamesContractResolver() : new DefaultContractResolver()
};
var binder = new TestBinder(expectedMessage);
var protocol = new JsonHubProtocol(jsonSerializer);
var message = protocol.ParseMessage(Encoding.UTF8.GetBytes(input), binder);
Assert.Equal(expectedMessage, message, TestEqualityComparer.Instance);
}
[Theory]
[InlineData("", "Error reading JSON.")]
[InlineData("null", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")]
[InlineData("42", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")]
[InlineData("'foo'", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")]
[InlineData("[42]", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")]
[InlineData("{}", "Missing required property 'type'.")]
[InlineData("{'type':1}", "Missing required property 'invocationId'.")]
[InlineData("{'type':1,'invocationId':42}", "Expected 'invocationId' to be of type String.")]
[InlineData("{'type':1,'invocationId':'42','target':42}", "Expected 'target' to be of type String.")]
[InlineData("{'type':1,'invocationId':'42','target':'foo'}", "Missing required property 'arguments'.")]
[InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':{}}", "Expected 'arguments' to be of type Array.")]
[InlineData("{'type':2}", "Missing required property 'invocationId'.")]
[InlineData("{'type':2,'invocationId':42}", "Expected 'invocationId' to be of type String.")]
[InlineData("{'type':2,'invocationId':'42'}", "Missing required property 'result'.")]
[InlineData("{'type':3}", "Missing required property 'invocationId'.")]
[InlineData("{'type':3,'invocationId':42}", "Expected 'invocationId' to be of type String.")]
[InlineData("{'type':3,'invocationId':'42','error':[]}", "Expected 'error' to be of type String.")]
[InlineData("{'type':4}", "Unknown message type: 4")]
[InlineData("{'type':'foo'}", "Expected 'type' to be of type Integer.")]
public void InvalidMessages(string input, string expectedMessage)
{
var binder = new TestBinder();
var protocol = new JsonHubProtocol(new JsonSerializer());
var ex = Assert.Throws<FormatException>(() => protocol.ParseMessage(Encoding.UTF8.GetBytes(input), binder));
Assert.Equal(expectedMessage, ex.Message);
}
[Theory]
[InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':[]}", "Invocation provides 0 argument(s) but target expects 2.")]
[InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':[42, 'foo'],'nonBlocking':42}", "Expected 'nonBlocking' to be of type Boolean.")]
[InlineData("{'type':3,'invocationId':'42','error':'foo','result':true}", "The 'error' and 'result' properties are mutually exclusive.")]
public void InvalidMessagesWithBinder(string input, string expectedMessage)
{
var binder = new TestBinder(paramTypes: new[] { typeof(int), typeof(string) }, returnType: typeof(bool));
var protocol = new JsonHubProtocol(new JsonSerializer());
var ex = Assert.Throws<FormatException>(() => protocol.ParseMessage(Encoding.UTF8.GetBytes(input), binder));
Assert.Equal(expectedMessage, ex.Message);
}
private class CustomObject : IEquatable<CustomObject>
{
// Not intended to be a full set of things, just a smattering of sample serializations
public string StringProp => "SignalR!";
public double DoubleProp => 6.2831853071;
public int IntProp => 42;
public DateTime DateTimeProp => new DateTime(2017, 4, 11);
public object NullProp => null;
public override bool Equals(object obj)
{
return obj is CustomObject o && Equals(o);
}
public override int GetHashCode()
{
// This is never used in a hash table
return 0;
}
public bool Equals(CustomObject right)
{
// This allows the comparer below to properly compare the object in the test.
return string.Equals(StringProp, right.StringProp, StringComparison.Ordinal) &&
DoubleProp == right.DoubleProp &&
IntProp == right.IntProp &&
DateTime.Equals(DateTimeProp, right.DateTimeProp) &&
NullProp == right.NullProp;
}
}
// Binder that works based on the expected message argument/result types :)
private class TestBinder : IInvocationBinder
{
private readonly Type[] _paramTypes;
private readonly Type _returnType;
public TestBinder(HubMessage expectedMessage)
{
switch(expectedMessage)
{
case InvocationMessage i:
_paramTypes = i.Arguments?.Select(a => a?.GetType() ?? typeof(object))?.ToArray();
break;
case StreamItemMessage s:
_returnType = s.Item?.GetType() ?? typeof(object);
break;
case CompletionMessage c:
_returnType = c.Result?.GetType() ?? typeof(object);
break;
}
}
public TestBinder() : this(null, null) { }
public TestBinder(Type[] paramTypes) : this(paramTypes, null) { }
public TestBinder(Type returnType) : this(null, returnType) {}
public TestBinder(Type[] paramTypes, Type returnType)
{
_paramTypes = paramTypes;
_returnType = returnType;
}
public Type[] GetParameterTypes(string methodName)
{
if (_paramTypes != null)
{
return _paramTypes;
}
throw new InvalidOperationException("Unexpected binder call");
}
public Type GetReturnType(string invocationId)
{
if (_returnType != null)
{
return _returnType;
}
throw new InvalidOperationException("Unexpected binder call");
}
}
private class TestEqualityComparer : IEqualityComparer<HubMessage>
{
public static readonly TestEqualityComparer Instance = new TestEqualityComparer();
private TestEqualityComparer() { }
public bool Equals(HubMessage x, HubMessage y)
{
if (!string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal))
{
return false;
}
return InvocationMessagesEqual(x, y) || StreamItemMessagesEqual(x, y) || CompletionMessagesEqual(x, y);
}
private bool CompletionMessagesEqual(HubMessage x, HubMessage y)
{
return x is CompletionMessage left && y is CompletionMessage right &&
string.Equals(left.Error, right.Error, StringComparison.Ordinal) &&
Equals(left.Result, right.Result) &&
left.HasResult == right.HasResult;
}
private bool StreamItemMessagesEqual(HubMessage x, HubMessage y)
{
return x is StreamItemMessage left && y is StreamItemMessage right &&
Equals(left.Item, right.Item);
}
private bool InvocationMessagesEqual(HubMessage x, HubMessage y)
{
return x is InvocationMessage left && y is InvocationMessage right &&
string.Equals(left.Target, right.Target, StringComparison.Ordinal) &&
Enumerable.SequenceEqual(left.Arguments, right.Arguments) &&
left.NonBlocking == right.NonBlocking;
}
public int GetHashCode(HubMessage obj)
{
// We never use these in a hash-table
return 0;
}
}
}
}

View File

@ -0,0 +1,31 @@
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="..\..\build\common.props" />
<PropertyGroup>
<TargetFrameworks>netcoreapp2.0;net46</TargetFrameworks>
<TargetFrameworks Condition="'$(OS)' != 'Windows_NT'">netcoreapp2.0</TargetFrameworks>
<!-- TODO remove when https://github.com/Microsoft/vstest/issues/428 is resolved -->
<AutoGenerateBindingRedirects>true</AutoGenerateBindingRedirects>
<GenerateBindingRedirectsOutputType>true</GenerateBindingRedirectsOutputType>
</PropertyGroup>
<ItemGroup>
<Compile Remove="Protocol\**" />
<EmbeddedResource Remove="Protocol\**" />
<None Remove="Protocol\**" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(TestSdkVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XunitVersion)" />
<PackageReference Include="xunit" Version="$(XunitVersion)" />
<PackageReference Include="System.ValueTuple" Version="$(CoreFxVersion)" />
</ItemGroup>
<ItemGroup>
<Service Include="{82a7f48d-3b50-4b1e-b82e-3ada8210c358}" />
</ItemGroup>
</Project>

View File

@ -1,14 +1,14 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;
using Moq;
using Xunit;
using Microsoft.AspNetCore.SignalR.Tests.Common;
namespace Microsoft.AspNetCore.SignalR.Tests
{
@ -19,7 +19,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var trackDispose = new TrackDispose();
var serviceProvider = CreateServiceProvider(s => s.AddSingleton(trackDispose));
var endPoint = serviceProvider.GetService<HubEndPoint<TestHub>>();
var endPoint = serviceProvider.GetService<HubEndPoint<DisposeTrackingHub>>();
using (var client = new TestClient(serviceProvider))
{
@ -127,10 +127,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var result = await client.Invoke<InvocationResultDescriptor>(nameof(MethodHub.TaskValueMethod)).OrTimeout();
var result = (await client.InvokeAsync(nameof(MethodHub.TaskValueMethod)).OrTimeout()).Result;
// json serializer makes this a long
Assert.Equal(42L, result.Result);
Assert.Equal(42L, result);
// kill the connection
client.Dispose();
@ -150,10 +150,33 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var result = await client.Invoke<InvocationResultDescriptor>("echo", "hello").OrTimeout();
var result = (await client.InvokeAsync("echo", "hello").OrTimeout()).Result;
Assert.Null(result.Error);
Assert.Equal("hello", result.Result);
Assert.Equal("hello", result);
// kill the connection
client.Dispose();
await endPointTask.OrTimeout();
}
}
[Theory]
[InlineData(nameof(MethodHub.MethodThatThrows))]
[InlineData(nameof(MethodHub.MethodThatYieldsFailedTask))]
public async Task HubMethodCanThrowOrYieldFailedTask(string methodName)
{
var serviceProvider = CreateServiceProvider();
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
using (var client = new TestClient(serviceProvider))
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var result = (await client.InvokeAsync(methodName).OrTimeout());
Assert.Equal("BOOM!", result.Error);
// kill the connection
client.Dispose();
@ -173,10 +196,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var result = await client.Invoke<InvocationResultDescriptor>(nameof(MethodHub.ValueMethod)).OrTimeout();
var result = (await client.InvokeAsync(nameof(MethodHub.ValueMethod)).OrTimeout()).Result;
// json serializer makes this a long
Assert.Equal(43L, result.Result);
Assert.Equal(43L, result);
// kill the connection
client.Dispose();
@ -196,9 +219,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var result = await client.Invoke<InvocationResultDescriptor>(nameof(MethodHub.VoidMethod)).OrTimeout();
var result = (await client.InvokeAsync(nameof(MethodHub.VoidMethod)).OrTimeout()).Result;
Assert.Null(result.Result);
Assert.Null(result);
// kill the connection
client.Dispose();
@ -218,9 +241,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var result = await client.Invoke<InvocationResultDescriptor>(nameof(MethodHub.ConcatString), (byte)32, 42, 'm', "string").OrTimeout();
var result = (await client.InvokeAsync(nameof(MethodHub.ConcatString), (byte)32, 42, 'm', "string").OrTimeout()).Result;
Assert.Equal("32, 42, m, string", result.Result);
Assert.Equal("32, 42, m, string", result);
// kill the connection
client.Dispose();
@ -240,9 +263,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var result = await client.Invoke<InvocationResultDescriptor>(nameof(InheritedHub.BaseMethod), "string").OrTimeout();
var result = (await client.InvokeAsync(nameof(InheritedHub.BaseMethod), "string").OrTimeout()).Result;
Assert.Equal("string", result.Result);
Assert.Equal("string", result);
// kill the connection
client.Dispose();
@ -262,9 +285,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var result = await client.Invoke<InvocationResultDescriptor>(nameof(InheritedHub.VirtualMethod), 10).OrTimeout();
var result = (await client.InvokeAsync(nameof(InheritedHub.VirtualMethod), 10).OrTimeout()).Result;
Assert.Equal(0L, result.Result);
Assert.Equal(0L, result);
// kill the connection
client.Dispose();
@ -284,7 +307,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var result = await client.Invoke<InvocationResultDescriptor>(nameof(MethodHub.OnDisconnectedAsync)).OrTimeout();
var result = await client.InvokeAsync(nameof(MethodHub.OnDisconnectedAsync)).OrTimeout();
Assert.Equal("Unknown hub method 'OnDisconnectedAsync'", result.Error);
@ -326,15 +349,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
await firstClient.Invoke(nameof(MethodHub.BroadcastMethod), "test").OrTimeout();
await firstClient.SendInvocationAsync(nameof(MethodHub.BroadcastMethod), "test").OrTimeout();
foreach (var result in await Task.WhenAll(
firstClient.Read<InvocationDescriptor>(),
secondClient.Read<InvocationDescriptor>()).OrTimeout())
firstClient.Read(),
secondClient.Read()).OrTimeout())
{
Assert.Equal("Broadcast", result.Method);
Assert.Equal(1, result.Arguments.Length);
Assert.Equal("test", result.Arguments[0]);
var invocation = Assert.IsType<InvocationMessage>(result);
Assert.Equal("Broadcast", invocation.Target);
Assert.Equal(1, invocation.Arguments.Length);
Assert.Equal("test", invocation.Arguments[0]);
}
// kill the connections
@ -360,23 +384,24 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
var result = await firstClient.Invoke<InvocationResultDescriptor>(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout();
var result = (await firstClient.InvokeAsync(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout()).Result;
// check that 'firstConnection' hasn't received the group send
Assert.Null(result.Id);
Assert.Null(firstClient.TryRead());
// check that 'secondConnection' hasn't received the group send
Assert.Null(await secondClient.TryRead<InvocationDescriptor>().OrTimeout());
Assert.Null(secondClient.TryRead());
result = await secondClient.Invoke<InvocationResultDescriptor>(nameof(MethodHub.GroupAddMethod), "testGroup").OrTimeout();
Assert.Null(result.Id);
result = (await secondClient.InvokeAsync(nameof(MethodHub.GroupAddMethod), "testGroup").OrTimeout()).Result;
await firstClient.Invoke(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout();
await firstClient.SendInvocationAsync(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout();
// check that 'secondConnection' has received the group send
var descriptor = await secondClient.Read<InvocationDescriptor>().OrTimeout();
Assert.Equal("Send", descriptor.Method);
Assert.Equal(1, descriptor.Arguments.Length);
Assert.Equal("test", descriptor.Arguments[0]);
var hubMessage = await secondClient.Read().OrTimeout();
var invocation = Assert.IsType<InvocationMessage>(hubMessage);
Assert.Equal("Send", invocation.Target);
Assert.Equal(1, invocation.Arguments.Length);
Assert.Equal("test", invocation.Arguments[0]);
// kill the connections
firstClient.Dispose();
@ -397,7 +422,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
await client.Invoke(nameof(MethodHub.GroupRemoveMethod), "testGroup").OrTimeout();
await client.SendInvocationAsync(nameof(MethodHub.GroupRemoveMethod), "testGroup").OrTimeout();
// kill the connection
client.Dispose();
@ -421,13 +446,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
await firstClient.Invoke(nameof(MethodHub.ClientSendMethod), secondClient.Connection.User.Identity.Name, "test").OrTimeout();
await firstClient.SendInvocationAsync(nameof(MethodHub.ClientSendMethod), secondClient.Connection.User.Identity.Name, "test").OrTimeout();
// check that 'secondConnection' has received the group send
var result = await secondClient.Read<InvocationDescriptor>().OrTimeout();
Assert.Equal("Send", result.Method);
Assert.Equal(1, result.Arguments.Length);
Assert.Equal("test", result.Arguments[0]);
var hubMessage = await secondClient.Read().OrTimeout();
var invocation = Assert.IsType<InvocationMessage>(hubMessage);
Assert.Equal("Send", invocation.Target);
Assert.Equal(1, invocation.Arguments.Length);
Assert.Equal("test", invocation.Arguments[0]);
// kill the connections
firstClient.Dispose();
@ -452,13 +478,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
await firstClient.Invoke(nameof(MethodHub.ConnectionSendMethod), secondClient.Connection.ConnectionId, "test").OrTimeout();
await firstClient.SendInvocationAsync(nameof(MethodHub.ConnectionSendMethod), secondClient.Connection.ConnectionId, "test").OrTimeout();
// check that 'secondConnection' has received the group send
var result = await secondClient.Read<InvocationDescriptor>().OrTimeout();
Assert.Equal("Send", result.Method);
Assert.Equal(1, result.Arguments.Length);
Assert.Equal("test", result.Arguments[0]);
var hubMessage = await secondClient.Read().OrTimeout();
var invocation = Assert.IsType<InvocationMessage>(hubMessage);
Assert.Equal("Send", invocation.Target);
Assert.Equal(1, invocation.Arguments.Length);
Assert.Equal("test", invocation.Arguments[0]);
// kill the connections
firstClient.Dispose();
@ -575,7 +602,17 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public override Task OnDisconnectedAsync(Exception e)
{
return TaskCache.CompletedTask;
return Task.CompletedTask;
}
public void MethodThatThrows()
{
throw new InvalidOperationException("BOOM!");
}
public Task MethodThatYieldsFailedTask()
{
return Task.FromException(new InvalidOperationException("BOOM!"));
}
}
@ -611,11 +648,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
private class TestHub : Hub
private class DisposeTrackingHub : Hub
{
private TrackDispose _trackDispose;
public TestHub(TrackDispose trackDispose)
public DisposeTrackingHub(TrackDispose trackDispose)
{
_trackDispose = trackDispose;
}

View File

@ -2,29 +2,30 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Collections.Generic;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.DependencyInjection;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Tests
{
public class TestClient : IDisposable
public class TestClient : IDisposable, IInvocationBinder
{
private static int _id;
private IInvocationAdapter _adapter;
private IHubProtocol _protocol;
private CancellationTokenSource _cts;
private TestBinder _binder;
public Connection Connection;
public IChannelConnection<Message> Application { get; }
public Task Connected => Connection.Metadata.Get<TaskCompletionSource<bool>>("ConnectedTask").Task;
public TestClient(IServiceProvider serviceProvider, string format = "json")
public TestClient(IServiceProvider serviceProvider)
{
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
@ -33,62 +34,80 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var transport = ChannelConnection.Create<Message>(input: transportToApplication, output: applicationToTransport);
Connection = new Connection(Guid.NewGuid().ToString(), transport);
Connection.Metadata["formatType"] = format;
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));
Connection.Metadata["ConnectedTask"] = new TaskCompletionSource<bool>();
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
_adapter = invocationAdapter.GetInvocationAdapter(format);
_binder = new TestBinder();
_protocol = new JsonHubProtocol(new JsonSerializer());
_cts = new CancellationTokenSource();
}
public async Task<T> Invoke<T>(string methodName, params object[] args) where T : InvocationMessage
public async Task<CompletionMessage> InvokeAsync(string methodName, params object[] args)
{
await Invoke(methodName, args);
var invocationId = await SendInvocationAsync(methodName, args);
return await Read<T>();
}
public async Task Invoke(string methodName, params object[] args)
{
var stream = new MemoryStream();
await _adapter.WriteMessageAsync(new InvocationDescriptor
while (true)
{
Arguments = args,
Method = methodName
},
stream);
var message = await Read();
await Application.Output.WriteAsync(new Message(stream.ToArray(), MessageType.Binary, endOfMessage: true));
}
public async Task<T> Read<T>() where T : InvocationMessage
{
while (await Application.Input.WaitToReadAsync(_cts.Token))
{
var value = await TryRead<T>();
if (value != null)
if (!string.Equals(message.InvocationId, invocationId))
{
return value;
throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
}
if (message == null)
{
throw new InvalidOperationException("Connection aborted!");
}
switch (message)
{
case StreamItemMessage result:
throw new NotSupportedException("TestClient does not support streaming!");
case CompletionMessage completion:
return completion;
default:
throw new NotSupportedException("TestClient does not support receiving invocations!");
}
}
return null;
}
public async Task<T> TryRead<T>() where T : InvocationMessage
public async Task<string> SendInvocationAsync(string methodName, params object[] args)
{
Message message;
if (Application.Input.TryRead(out message))
{
var value = await _adapter.ReadMessageAsync(new MemoryStream(message.Payload), _binder);
return value as T;
}
var invocationId = GetInvocationId();
var payload = await _protocol.WriteToArrayAsync(new InvocationMessage(invocationId, nonBlocking: false, target: methodName, arguments: args));
await Application.Output.WriteAsync(new Message(payload, _protocol.MessageType, endOfMessage: true));
return invocationId;
}
public async Task<HubMessage> Read()
{
while (true)
{
var message = TryRead();
if (message == null)
{
if (!await Application.Input.WaitToReadAsync())
{
return null;
}
}
else
{
return message;
}
}
}
public HubMessage TryRead()
{
if (Application.Input.TryRead(out var message))
{
return _protocol.ParseMessage(message.Payload, this);
}
return null;
}
@ -98,18 +117,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Connection.Dispose();
}
private class TestBinder : IInvocationBinder
private static string GetInvocationId()
{
public Type[] GetParameterTypes(string methodName)
{
// TODO: Possibly support actual client methods
return new[] { typeof(object) };
}
return Guid.NewGuid().ToString("N");
}
public Type GetReturnType(string invocationId)
{
return typeof(object);
}
Type[] IInvocationBinder.GetParameterTypes(string methodName)
{
// TODO: Possibly support actual client methods
return new[] { typeof(object) };
}
Type IInvocationBinder.GetReturnType(string invocationId)
{
return typeof(object);
}
}
}