Implement new Hub Protocol (Part Deux) (#390)
* convert to new protocol * removed InvocationDescriptorRegistry because we're not yet sure about custom protocols * update SocialWeather sample * Moving ts client to using new protocol * make the functional tests a little easier to run on ctrl-f5
This commit is contained in:
parent
6cf6feed64
commit
991c1d8517
|
|
@ -1,6 +1,6 @@
|
|||
Microsoft Visual Studio Solution File, Format Version 12.00
|
||||
# Visual Studio 15
|
||||
VisualStudioVersion = 15.0.26228.9
|
||||
VisualStudioVersion = 15.0.26411.1
|
||||
MinimumVisualStudioVersion = 10.0.40219.1
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}"
|
||||
EndProject
|
||||
|
|
@ -74,6 +74,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "client-ts", "client-ts", "{
|
|||
EndProject
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.SignalR.Microbenchmarks", "test\Microsoft.AspNetCore.SignalR.Microbenchmarks\Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj", "{96771B3F-4D18-41A7-A75B-FF38E76AAC89}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.Common.Tests", "test\Microsoft.AspNetCore.SignalR.Common.Tests\Microsoft.AspNetCore.SignalR.Common.Tests.csproj", "{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}"
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
Debug|Any CPU = Debug|Any CPU
|
||||
|
|
@ -180,6 +182,10 @@ Global
|
|||
{96771B3F-4D18-41A7-A75B-FF38E76AAC89}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{96771B3F-4D18-41A7-A75B-FF38E76AAC89}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{96771B3F-4D18-41A7-A75B-FF38E76AAC89}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
EndGlobalSection
|
||||
GlobalSection(SolutionProperties) = preSolution
|
||||
HideSolutionNode = FALSE
|
||||
|
|
@ -211,5 +217,6 @@ Global
|
|||
{333526A4-633B-491A-AC45-CC62A0012D1C} = {3A76C5A2-79ED-49BC-8BDC-6A3A766FFA1B}
|
||||
{6CEC3DC2-5B01-45A8-8F0D-8531315DA90B} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
|
||||
{96771B3F-4D18-41A7-A75B-FF38E76AAC89} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
|
||||
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
|
||||
EndGlobalSection
|
||||
EndGlobal
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
<Project>
|
||||
<Project>
|
||||
<PropertyGroup>
|
||||
<AspNetCoreIntegrationTestingVersion>0.4.0-*</AspNetCoreIntegrationTestingVersion>
|
||||
<SystemMemoryVersion>4.4.0-*</SystemMemoryVersion>
|
||||
<AspNetCoreVersion>2.0.0-*</AspNetCoreVersion>
|
||||
<CoreFxLabsVersion>0.1.0-*</CoreFxLabsVersion>
|
||||
<CoreFxVersion>4.3.0</CoreFxVersion>
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import { DataReceived, ConnectionClosed } from "../Microsoft.AspNetCore.SignalR.
|
|||
import { TransportType, ITransport } from "../Microsoft.AspNetCore.SignalR.Client.TS/Transports"
|
||||
|
||||
describe("HubConnection", () => {
|
||||
it("completes pending invocations when stopped", async (done) => {
|
||||
it("completes pending invocations when stopped", async done => {
|
||||
let connection: IConnection = {
|
||||
start(transportType: TransportType | ITransport): Promise<void> {
|
||||
return Promise.resolve();
|
||||
|
|
@ -27,18 +27,18 @@ describe("HubConnection", () => {
|
|||
let hubConnection = new HubConnection(connection);
|
||||
var invokePromise = hubConnection.invoke("testMethod");
|
||||
hubConnection.stop();
|
||||
invokePromise
|
||||
.then(() => {
|
||||
fail();
|
||||
done();
|
||||
})
|
||||
.catch((error: Error) => {
|
||||
expect(error.message).toBe("Invocation cancelled due to connection being closed.");
|
||||
done();
|
||||
});
|
||||
|
||||
try {
|
||||
await invokePromise;
|
||||
fail();
|
||||
}
|
||||
catch (e) {
|
||||
expect(e.message).toBe("Invocation cancelled due to connection being closed.");
|
||||
}
|
||||
done();
|
||||
});
|
||||
|
||||
it("completes pending invocations when connection is lost", async (done) => {
|
||||
it("completes pending invocations when connection is lost", async done => {
|
||||
let connection: IConnection = {
|
||||
start(transportType: TransportType | ITransport): Promise<void> {
|
||||
return Promise.resolve();
|
||||
|
|
@ -60,17 +60,92 @@ describe("HubConnection", () => {
|
|||
|
||||
let hubConnection = new HubConnection(connection);
|
||||
var invokePromise = hubConnection.invoke("testMethod");
|
||||
invokePromise
|
||||
.then(() => {
|
||||
fail();
|
||||
done();
|
||||
})
|
||||
.catch((error: Error) => {
|
||||
expect(error.message).toBe("Connection lost");
|
||||
done();
|
||||
});
|
||||
|
||||
// Typically this would be called by the transport
|
||||
connection.onClosed(new Error("Connection lost"));
|
||||
|
||||
try {
|
||||
await invokePromise;
|
||||
fail();
|
||||
}
|
||||
catch (e) {
|
||||
expect(e.message).toBe("Connection lost");
|
||||
}
|
||||
done();
|
||||
});
|
||||
|
||||
it("sends invocations as nonblocking", async done => {
|
||||
let dataSent: string;
|
||||
let connection: IConnection = {
|
||||
start(transportType: TransportType): Promise<void> {
|
||||
return Promise.resolve();
|
||||
},
|
||||
|
||||
send(data: any): Promise<void> {
|
||||
dataSent = data;
|
||||
return Promise.resolve();
|
||||
},
|
||||
|
||||
stop(): void {
|
||||
if (this.onClosed) {
|
||||
this.onClosed();
|
||||
}
|
||||
},
|
||||
|
||||
onDataReceived: null,
|
||||
onClosed: null
|
||||
};
|
||||
|
||||
let hubConnection = new HubConnection(connection);
|
||||
let invokePromise = hubConnection.invoke("testMethod");
|
||||
|
||||
expect(JSON.parse(dataSent).nonblocking).toBe(false);
|
||||
|
||||
// will clean pending promises
|
||||
connection.onClosed();
|
||||
|
||||
try {
|
||||
await invokePromise;
|
||||
fail(); // exception is expected because the call has not completed
|
||||
}
|
||||
catch (e) {
|
||||
}
|
||||
done();
|
||||
});
|
||||
|
||||
it("rejects streaming responses", async done => {
|
||||
let connection: IConnection = {
|
||||
start(transportType: TransportType): Promise<void> {
|
||||
return Promise.resolve();
|
||||
},
|
||||
|
||||
send(data: any): Promise<void> {
|
||||
return Promise.resolve();
|
||||
},
|
||||
|
||||
stop(): void {
|
||||
if (this.onClosed) {
|
||||
this.onClosed();
|
||||
}
|
||||
},
|
||||
|
||||
onDataReceived: null,
|
||||
onClosed: null
|
||||
};
|
||||
|
||||
let hubConnection = new HubConnection(connection);
|
||||
let invokePromise = hubConnection.invoke("testMethod");
|
||||
|
||||
connection.onDataReceived("{ \"type\": 2, \"invocationId\": \"0\", \"result\": null }");
|
||||
connection.onClosed();
|
||||
|
||||
try {
|
||||
await invokePromise;
|
||||
fail();
|
||||
}
|
||||
catch (e) {
|
||||
expect(e.message).toBe("Streaming is not supported.");
|
||||
}
|
||||
|
||||
done();
|
||||
});
|
||||
});
|
||||
|
|
@ -3,16 +3,31 @@ import { IConnection } from "./IConnection"
|
|||
import { Connection } from "./Connection"
|
||||
import { TransportType } from "./Transports"
|
||||
|
||||
interface InvocationDescriptor {
|
||||
readonly Id: string;
|
||||
readonly Method: string;
|
||||
readonly Arguments: Array<any>;
|
||||
|
||||
const enum MessageType {
|
||||
Invocation = 1,
|
||||
Result,
|
||||
Completion
|
||||
}
|
||||
|
||||
interface InvocationResultDescriptor {
|
||||
readonly Id: string;
|
||||
readonly Error: string;
|
||||
readonly Result: any;
|
||||
interface HubMessage {
|
||||
readonly type: MessageType;
|
||||
readonly invocationId: string;
|
||||
}
|
||||
|
||||
interface InvocationMessage extends HubMessage {
|
||||
readonly target: string;
|
||||
readonly arguments: Array<any>;
|
||||
readonly nonblocking?: boolean;
|
||||
}
|
||||
|
||||
interface ResultMessage extends HubMessage {
|
||||
readonly result?: any;
|
||||
}
|
||||
|
||||
interface CompletionMessage extends HubMessage {
|
||||
readonly error?: string;
|
||||
readonly result?: any;
|
||||
}
|
||||
|
||||
export { Connection } from "./Connection"
|
||||
|
|
@ -20,7 +35,7 @@ export { TransportType } from "./Transports"
|
|||
|
||||
export class HubConnection {
|
||||
private connection: IConnection;
|
||||
private callbacks: Map<string, (invocationDescriptor: InvocationResultDescriptor) => void>;
|
||||
private callbacks: Map<string, (invocationUpdate: CompletionMessage|ResultMessage) => void>;
|
||||
private methods: Map<string, (...args: any[]) => void>;
|
||||
private id: number;
|
||||
private connectionClosedCallback: ConnectionClosed;
|
||||
|
|
@ -40,7 +55,7 @@ export class HubConnection {
|
|||
this.onConnectionClosed(error);
|
||||
}
|
||||
|
||||
this.callbacks = new Map<string, (invocationDescriptor: InvocationResultDescriptor) => void>();
|
||||
this.callbacks = new Map<string, (invocationEvent: CompletionMessage|ResultMessage) => void>();
|
||||
this.methods = new Map<string, (...args: any[]) => void>();
|
||||
this.id = 0;
|
||||
}
|
||||
|
|
@ -51,34 +66,49 @@ export class HubConnection {
|
|||
if (!data) {
|
||||
return;
|
||||
}
|
||||
var descriptor = JSON.parse(data);
|
||||
if (descriptor.Method === undefined) {
|
||||
let invocationResult: InvocationResultDescriptor = descriptor;
|
||||
let callback = this.callbacks.get(invocationResult.Id);
|
||||
if (callback != null) {
|
||||
callback(invocationResult);
|
||||
this.callbacks.delete(invocationResult.Id);
|
||||
|
||||
var message = JSON.parse(data);
|
||||
switch (message.type) {
|
||||
case MessageType.Invocation:
|
||||
this.InvokeClientMethod(<InvocationMessage>message);
|
||||
break;
|
||||
case MessageType.Result:
|
||||
// TODO: Streaming (MessageType.Result) currently not supported - callback will throw
|
||||
case MessageType.Completion:
|
||||
let callback = this.callbacks.get(message.invocationId);
|
||||
if (callback != null) {
|
||||
callback(message);
|
||||
this.callbacks.delete(message.invocationId);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
console.log("Invalid message type: " + data);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private InvokeClientMethod(invocationMessage: InvocationMessage) {
|
||||
let method = this.methods.get(invocationMessage.target);
|
||||
if (method) {
|
||||
method.apply(this, invocationMessage.arguments);
|
||||
if (!invocationMessage.nonblocking) {
|
||||
// TODO: send result back to the server?
|
||||
}
|
||||
}
|
||||
else {
|
||||
let invocation: InvocationDescriptor = descriptor;
|
||||
let method = this.methods[invocation.Method];
|
||||
if (method != null) {
|
||||
// TODO: bind? args?
|
||||
method.apply(this, invocation.Arguments);
|
||||
}
|
||||
console.log(`No client method with the name '${invocationMessage.target}' found.`);
|
||||
}
|
||||
}
|
||||
|
||||
private onConnectionClosed(error: Error) {
|
||||
let errorInvocationResult = {
|
||||
Id: "-1",
|
||||
Error: error ? error.message : "Invocation cancelled due to connection being closed.",
|
||||
Result: null
|
||||
} as InvocationResultDescriptor;
|
||||
let errorCompletionMessage = <CompletionMessage>{
|
||||
type: MessageType.Completion,
|
||||
invocationId: "-1",
|
||||
error: error ? error.message : "Invocation cancelled due to connection being closed.",
|
||||
};
|
||||
|
||||
this.callbacks.forEach(callback => {
|
||||
callback(errorInvocationResult);
|
||||
callback(errorCompletionMessage);
|
||||
});
|
||||
this.callbacks.clear();
|
||||
|
||||
|
|
@ -99,19 +129,27 @@ export class HubConnection {
|
|||
let id = this.id;
|
||||
this.id++;
|
||||
|
||||
let invocationDescriptor: InvocationDescriptor = {
|
||||
"Id": id.toString(),
|
||||
"Method": methodName,
|
||||
"Arguments": args
|
||||
let invocationDescriptor: InvocationMessage = {
|
||||
type: MessageType.Invocation,
|
||||
invocationId: id.toString(),
|
||||
target: methodName,
|
||||
arguments: args,
|
||||
nonblocking: false
|
||||
};
|
||||
|
||||
let p = new Promise<any>((resolve, reject) => {
|
||||
this.callbacks.set(invocationDescriptor.Id, (invocationResult: InvocationResultDescriptor) => {
|
||||
if (invocationResult.Error != null) {
|
||||
reject(new Error(invocationResult.Error));
|
||||
this.callbacks.set(invocationDescriptor.invocationId, (invocationEvent: CompletionMessage | ResultMessage) => {
|
||||
if (invocationEvent.type === MessageType.Completion) {
|
||||
let completionMessage = <CompletionMessage>invocationEvent;
|
||||
if (completionMessage.error) {
|
||||
reject(new Error(completionMessage.error));
|
||||
}
|
||||
else {
|
||||
resolve(completionMessage.result);
|
||||
}
|
||||
}
|
||||
else {
|
||||
resolve(invocationResult.Result);
|
||||
reject(new Error("Streaming is not supported."))
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -119,7 +157,7 @@ export class HubConnection {
|
|||
this.connection.send(JSON.stringify(invocationDescriptor))
|
||||
.catch(e => {
|
||||
reject(e);
|
||||
this.callbacks.delete(invocationDescriptor.Id);
|
||||
this.callbacks.delete(invocationDescriptor.invocationId);
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -127,7 +165,7 @@ export class HubConnection {
|
|||
}
|
||||
|
||||
on(methodName: string, method: (...args: any[]) => void) {
|
||||
this.methods[methodName] = method;
|
||||
this.methods.set(methodName, method);
|
||||
}
|
||||
|
||||
set onClosed(callback: ConnectionClosed) {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
|
|
@ -24,7 +24,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server
|
|||
app.UseDeveloperExceptionPage();
|
||||
}
|
||||
|
||||
app.UseStaticFiles();
|
||||
app.UseFileServer();
|
||||
app.UseSockets(options => options.MapEndpoint<EchoEndPoint>("/echo"));
|
||||
app.UseSignalR(routes =>
|
||||
{
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>SignalR Tests</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>SignalR Tests</h1>
|
||||
<ul>
|
||||
<li><a href="connectionTests.html">Connection Tests</a></li>
|
||||
</ul>
|
||||
</body>
|
||||
</html>
|
||||
|
|
@ -5,7 +5,6 @@ using System;
|
|||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR;
|
||||
using Microsoft.AspNetCore.SignalR.Client;
|
||||
using Microsoft.Extensions.CommandLineUtils;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
|
@ -33,7 +32,7 @@ namespace ClientSample
|
|||
var loggerFactory = new LoggerFactory();
|
||||
|
||||
Console.WriteLine("Connecting to {0}", baseUrl);
|
||||
var connection = new HubConnection(new Uri(baseUrl), new JsonNetInvocationAdapter(), loggerFactory);
|
||||
var connection = new HubConnection(new Uri(baseUrl), loggerFactory);
|
||||
try
|
||||
{
|
||||
await connection.StartAsync();
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
|
|
@ -22,6 +22,7 @@ namespace SocialWeather
|
|||
|
||||
public void OnConnectedAsync(Connection connection)
|
||||
{
|
||||
connection.Metadata[ConnectionMetadataNames.Format] = "json";
|
||||
_connectionList.Add(connection);
|
||||
}
|
||||
|
||||
|
|
@ -34,11 +35,11 @@ namespace SocialWeather
|
|||
{
|
||||
foreach (var connection in _connectionList)
|
||||
{
|
||||
var formatter = _formatterResolver.GetFormatter<T>(connection.Metadata.Get<string>("formatType"));
|
||||
var formatter = _formatterResolver.GetFormatter<T>(connection.Metadata.Get<string>(ConnectionMetadataNames.Format));
|
||||
var ms = new MemoryStream();
|
||||
await formatter.WriteAsync(data, ms);
|
||||
|
||||
var context = (HttpContext)connection.Metadata[typeof(HttpContext)];
|
||||
var context = (HttpContext)connection.Metadata[ConnectionMetadataNames.HttpContext];
|
||||
var format =
|
||||
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
|
||||
? MessageType.Binary
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ namespace SocketsSample.EndPoints
|
|||
{
|
||||
Connections.Add(connection);
|
||||
|
||||
await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})");
|
||||
await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata[ConnectionMetadataNames.Transport]})");
|
||||
|
||||
try
|
||||
{
|
||||
|
|
@ -37,7 +37,7 @@ namespace SocketsSample.EndPoints
|
|||
{
|
||||
Connections.Remove(connection);
|
||||
|
||||
await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})");
|
||||
await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata[ConnectionMetadataNames.Transport]})");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,91 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR;
|
||||
|
||||
namespace SocketsSample
|
||||
{
|
||||
public class LineInvocationAdapter : IInvocationAdapter
|
||||
{
|
||||
public async Task<InvocationMessage> ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken)
|
||||
{
|
||||
var streamReader = new StreamReader(stream);
|
||||
var line = await streamReader.ReadLineAsync();
|
||||
if (line == null)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var values = line.Split(',');
|
||||
|
||||
var type = values[0].Substring(0, 2);
|
||||
var id = values[0].Substring(2);
|
||||
|
||||
if (type.Equals("RI"))
|
||||
{
|
||||
var resultType = values[1].Substring(0, 1);
|
||||
var result = values[1].Substring(1);
|
||||
return new InvocationResultDescriptor()
|
||||
{
|
||||
Id = id,
|
||||
Result = resultType.Equals("E") ? null : result,
|
||||
Error = resultType.Equals("E") ? result : null,
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
var method = values[1].Substring(1);
|
||||
|
||||
return new InvocationDescriptor
|
||||
{
|
||||
Id = id,
|
||||
Method = method,
|
||||
Arguments = values.Skip(2).Zip(binder.GetParameterTypes(method), (v, t) => Convert.ChangeType(v, t)).ToArray()
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
public Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken)
|
||||
{
|
||||
var invocationDescriptor = message as InvocationDescriptor;
|
||||
if (invocationDescriptor != null)
|
||||
{
|
||||
return WriteInvocationDescriptorAsync(invocationDescriptor, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
return WriteInvocationResultAsync((InvocationResultDescriptor)message, stream);
|
||||
}
|
||||
}
|
||||
|
||||
private Task WriteInvocationDescriptorAsync(InvocationDescriptor invocationDescriptor, Stream stream)
|
||||
{
|
||||
var msg = $"CI{invocationDescriptor.Id},M{invocationDescriptor.Method},{string.Join(",", invocationDescriptor.Arguments.Select(a => a.ToString()))}\n";
|
||||
return WriteAsync(msg, stream);
|
||||
}
|
||||
|
||||
private Task WriteInvocationResultAsync(InvocationResultDescriptor resultDescriptor, Stream stream)
|
||||
{
|
||||
if (string.IsNullOrEmpty(resultDescriptor.Error))
|
||||
{
|
||||
return WriteAsync($"RI{resultDescriptor.Id},E{resultDescriptor.Error}\n", stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
return WriteAsync($"RI{resultDescriptor.Id},R{(resultDescriptor.Result != null ? resultDescriptor.Result.ToString() : string.Empty)}\n", stream);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task WriteAsync(string msg, Stream stream)
|
||||
{
|
||||
var writer = new StreamWriter(stream);
|
||||
await writer.WriteAsync(msg);
|
||||
await writer.FlushAsync();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,168 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Google.Protobuf;
|
||||
using Microsoft.AspNetCore.SignalR;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
||||
namespace SocketsSample.Protobuf
|
||||
{
|
||||
public class ProtobufInvocationAdapter : IInvocationAdapter
|
||||
{
|
||||
private IServiceProvider _serviceProvider;
|
||||
|
||||
public ProtobufInvocationAdapter(IServiceProvider serviceProvider)
|
||||
{
|
||||
_serviceProvider = serviceProvider;
|
||||
}
|
||||
|
||||
public Task<InvocationMessage> ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken)
|
||||
{
|
||||
return Task.Run(() => CreateInvocationMessageInt(stream, binder));
|
||||
}
|
||||
|
||||
public Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private Task<InvocationMessage> CreateInvocationMessageInt(Stream stream, IInvocationBinder binder)
|
||||
{
|
||||
var inputStream = new CodedInputStream(stream, leaveOpen: true);
|
||||
var messageKind = new RpcMessageKind();
|
||||
inputStream.ReadMessage(messageKind);
|
||||
if(messageKind.MessageKind == RpcMessageKind.Types.Kind.Invocation)
|
||||
{
|
||||
return CreateInvocationDescriptorInt(inputStream, binder);
|
||||
}
|
||||
else
|
||||
{
|
||||
return CreateInvocationResultDescriptorInt(inputStream, binder);
|
||||
}
|
||||
}
|
||||
|
||||
private Task<InvocationMessage> CreateInvocationResultDescriptorInt(CodedInputStream inputStream, IInvocationBinder binder)
|
||||
{
|
||||
throw new NotImplementedException("Not yet implemented for Protobuf");
|
||||
}
|
||||
|
||||
private Task<InvocationMessage> CreateInvocationDescriptorInt(CodedInputStream inputStream, IInvocationBinder binder)
|
||||
{
|
||||
var invocationHeader = new RpcInvocationHeader();
|
||||
inputStream.ReadMessage(invocationHeader);
|
||||
var argumentTypes = binder.GetParameterTypes(invocationHeader.Name);
|
||||
|
||||
var invocationDescriptor = new InvocationDescriptor();
|
||||
invocationDescriptor.Method = invocationHeader.Name;
|
||||
invocationDescriptor.Id = invocationHeader.Id.ToString();
|
||||
invocationDescriptor.Arguments = new object[argumentTypes.Length];
|
||||
|
||||
var primitiveParser = PrimitiveValue.Parser;
|
||||
|
||||
for (var i = 0; i < argumentTypes.Length; i++)
|
||||
{
|
||||
if (typeof(int) == argumentTypes[i])
|
||||
{
|
||||
var value = new PrimitiveValue();
|
||||
inputStream.ReadMessage(value);
|
||||
invocationDescriptor.Arguments[i] = value.Int32Value;
|
||||
}
|
||||
else if (typeof(string) == argumentTypes[i])
|
||||
{
|
||||
var value = new PrimitiveValue();
|
||||
inputStream.ReadMessage(value);
|
||||
invocationDescriptor.Arguments[i] = value.StringValue;
|
||||
}
|
||||
else
|
||||
{
|
||||
var serializer = _serviceProvider.GetRequiredService<ProtobufSerializer>();
|
||||
invocationDescriptor.Arguments[i] = serializer.GetValue(inputStream, argumentTypes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
return Task.FromResult<InvocationMessage>(invocationDescriptor);
|
||||
}
|
||||
|
||||
public async Task WriteInvocationResultAsync(InvocationResultDescriptor resultDescriptor, Stream stream)
|
||||
{
|
||||
var outputStream = new CodedOutputStream(stream, leaveOpen: true);
|
||||
outputStream.WriteMessage(new RpcMessageKind() { MessageKind = RpcMessageKind.Types.Kind.Result });
|
||||
|
||||
var resultHeader = new RpcInvocationResultHeader
|
||||
{
|
||||
Id = int.Parse(resultDescriptor.Id),
|
||||
HasResult = resultDescriptor.Result != null
|
||||
};
|
||||
|
||||
if (resultDescriptor.Error != null)
|
||||
{
|
||||
resultHeader.Error = resultDescriptor.Error;
|
||||
}
|
||||
|
||||
outputStream.WriteMessage(resultHeader);
|
||||
|
||||
if (string.IsNullOrEmpty(resultHeader.Error) && resultDescriptor.Result != null)
|
||||
{
|
||||
var result = resultDescriptor.Result;
|
||||
|
||||
if (result.GetType() == typeof(int))
|
||||
{
|
||||
outputStream.WriteMessage(new PrimitiveValue { Int32Value = (int)result });
|
||||
}
|
||||
else if (result.GetType() == typeof(string))
|
||||
{
|
||||
outputStream.WriteMessage(new PrimitiveValue { StringValue = (string)result });
|
||||
}
|
||||
else
|
||||
{
|
||||
var serializer = _serviceProvider.GetRequiredService<ProtobufSerializer>();
|
||||
var message = serializer.GetMessage(result);
|
||||
outputStream.WriteMessage(message);
|
||||
}
|
||||
}
|
||||
|
||||
outputStream.Flush();
|
||||
await stream.FlushAsync();
|
||||
}
|
||||
|
||||
public async Task WriteInvocationDescriptorAsync(InvocationDescriptor invocationDescriptor, Stream stream)
|
||||
{
|
||||
var outputStream = new CodedOutputStream(stream, leaveOpen: true);
|
||||
outputStream.WriteMessage(new RpcMessageKind() { MessageKind = RpcMessageKind.Types.Kind.Invocation });
|
||||
|
||||
var invocationHeader = new RpcInvocationHeader()
|
||||
{
|
||||
Id = 0,
|
||||
Name = invocationDescriptor.Method,
|
||||
NumArgs = invocationDescriptor.Arguments.Length
|
||||
};
|
||||
|
||||
outputStream.WriteMessage(invocationHeader);
|
||||
|
||||
foreach (var arg in invocationDescriptor.Arguments)
|
||||
{
|
||||
if (arg.GetType() == typeof(int))
|
||||
{
|
||||
outputStream.WriteMessage(new PrimitiveValue { Int32Value = (int)arg });
|
||||
}
|
||||
else if (arg.GetType() == typeof(string))
|
||||
{
|
||||
outputStream.WriteMessage(new PrimitiveValue { StringValue = (string)arg });
|
||||
}
|
||||
else
|
||||
{
|
||||
var serializer = _serviceProvider.GetRequiredService<ProtobufSerializer>();
|
||||
var message = serializer.GetMessage(arg);
|
||||
outputStream.WriteMessage(message);
|
||||
}
|
||||
}
|
||||
|
||||
outputStream.Flush();
|
||||
await stream.FlushAsync();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,848 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
// Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
// source: RpcInvocation.proto
|
||||
#pragma warning disable 1591, 0612, 3021
|
||||
#region Designer generated code
|
||||
|
||||
using pb = global::Google.Protobuf;
|
||||
using pbc = global::Google.Protobuf.Collections;
|
||||
using pbr = global::Google.Protobuf.Reflection;
|
||||
using scg = global::System.Collections.Generic;
|
||||
/// <summary>Holder for reflection information generated from RpcInvocation.proto</summary>
|
||||
public static partial class RpcInvocationReflection {
|
||||
|
||||
#region Descriptor
|
||||
/// <summary>File descriptor for RpcInvocation.proto</summary>
|
||||
public static pbr::FileDescriptor Descriptor {
|
||||
get { return descriptor; }
|
||||
}
|
||||
private static pbr::FileDescriptor descriptor;
|
||||
|
||||
static RpcInvocationReflection() {
|
||||
byte[] descriptorData = global::System.Convert.FromBase64String(
|
||||
string.Concat(
|
||||
"ChNScGNJbnZvY2F0aW9uLnByb3RvIl8KDlJwY01lc3NhZ2VLaW5kEikKC01l",
|
||||
"c3NhZ2VLaW5kGAEgASgOMhQuUnBjTWVzc2FnZUtpbmQuS2luZCIiCgRLaW5k",
|
||||
"EgoKBlJlc3VsdBAAEg4KCkludm9jYXRpb24QASJAChNScGNJbnZvY2F0aW9u",
|
||||
"SGVhZGVyEgwKBE5hbWUYASABKAkSCgoCSWQYAiABKAUSDwoHTnVtQXJncxgD",
|
||||
"IAEoBSJJChlScGNJbnZvY2F0aW9uUmVzdWx0SGVhZGVyEgoKAklkGAEgASgF",
|
||||
"EhEKCUhhc1Jlc3VsdBgCIAEoCBINCgVFcnJvchgDIAEoCSJHCg5QcmltaXRp",
|
||||
"dmVWYWx1ZRIUCgpJbnQzMlZhbHVlGAEgASgFSAASFQoLU3RyaW5nVmFsdWUY",
|
||||
"AiABKAlIAEIICgZvbmVvZl8iKgoNUGVyc29uTWVzc2FnZRIMCgROYW1lGAEg",
|
||||
"ASgJEgsKA0FnZRgCIAEoA2IGcHJvdG8z"));
|
||||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
|
||||
new pbr::FileDescriptor[] { },
|
||||
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
|
||||
new pbr::GeneratedClrTypeInfo(typeof(global::RpcMessageKind), global::RpcMessageKind.Parser, new[]{ "MessageKind" }, null, new[]{ typeof(global::RpcMessageKind.Types.Kind) }, null),
|
||||
new pbr::GeneratedClrTypeInfo(typeof(global::RpcInvocationHeader), global::RpcInvocationHeader.Parser, new[]{ "Name", "Id", "NumArgs" }, null, null, null),
|
||||
new pbr::GeneratedClrTypeInfo(typeof(global::RpcInvocationResultHeader), global::RpcInvocationResultHeader.Parser, new[]{ "Id", "HasResult", "Error" }, null, null, null),
|
||||
new pbr::GeneratedClrTypeInfo(typeof(global::PrimitiveValue), global::PrimitiveValue.Parser, new[]{ "Int32Value", "StringValue" }, new[]{ "Oneof" }, null, null),
|
||||
new pbr::GeneratedClrTypeInfo(typeof(global::PersonMessage), global::PersonMessage.Parser, new[]{ "Name", "Age" }, null, null, null)
|
||||
}));
|
||||
}
|
||||
#endregion
|
||||
|
||||
}
|
||||
#region Messages
|
||||
public sealed partial class RpcMessageKind : pb::IMessage<RpcMessageKind> {
|
||||
private static readonly pb::MessageParser<RpcMessageKind> _parser = new pb::MessageParser<RpcMessageKind>(() => new RpcMessageKind());
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public static pb::MessageParser<RpcMessageKind> Parser { get { return _parser; } }
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public static pbr::MessageDescriptor Descriptor {
|
||||
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[0]; }
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
pbr::MessageDescriptor pb::IMessage.Descriptor {
|
||||
get { return Descriptor; }
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public RpcMessageKind() {
|
||||
OnConstruction();
|
||||
}
|
||||
|
||||
partial void OnConstruction();
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public RpcMessageKind(RpcMessageKind other) : this() {
|
||||
messageKind_ = other.messageKind_;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public RpcMessageKind Clone() {
|
||||
return new RpcMessageKind(this);
|
||||
}
|
||||
|
||||
/// <summary>Field number for the "MessageKind" field.</summary>
|
||||
public const int MessageKindFieldNumber = 1;
|
||||
private global::RpcMessageKind.Types.Kind messageKind_ = 0;
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public global::RpcMessageKind.Types.Kind MessageKind {
|
||||
get { return messageKind_; }
|
||||
set {
|
||||
messageKind_ = value;
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override bool Equals(object other) {
|
||||
return Equals(other as RpcMessageKind);
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public bool Equals(RpcMessageKind other) {
|
||||
if (ReferenceEquals(other, null)) {
|
||||
return false;
|
||||
}
|
||||
if (ReferenceEquals(other, this)) {
|
||||
return true;
|
||||
}
|
||||
if (MessageKind != other.MessageKind) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override int GetHashCode() {
|
||||
int hash = 1;
|
||||
if (MessageKind != 0) hash ^= MessageKind.GetHashCode();
|
||||
return hash;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override string ToString() {
|
||||
return pb::JsonFormatter.ToDiagnosticString(this);
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void WriteTo(pb::CodedOutputStream output) {
|
||||
if (MessageKind != 0) {
|
||||
output.WriteRawTag(8);
|
||||
output.WriteEnum((int) MessageKind);
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public int CalculateSize() {
|
||||
int size = 0;
|
||||
if (MessageKind != 0) {
|
||||
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) MessageKind);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void MergeFrom(RpcMessageKind other) {
|
||||
if (other == null) {
|
||||
return;
|
||||
}
|
||||
if (other.MessageKind != 0) {
|
||||
MessageKind = other.MessageKind;
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void MergeFrom(pb::CodedInputStream input) {
|
||||
uint tag;
|
||||
while ((tag = input.ReadTag()) != 0) {
|
||||
switch(tag) {
|
||||
default:
|
||||
input.SkipLastField();
|
||||
break;
|
||||
case 8: {
|
||||
messageKind_ = (global::RpcMessageKind.Types.Kind) input.ReadEnum();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#region Nested types
|
||||
/// <summary>Container for nested types declared in the RpcMessageKind message type.</summary>
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public static partial class Types {
|
||||
public enum Kind {
|
||||
[pbr::OriginalName("Result")] Result = 0,
|
||||
[pbr::OriginalName("Invocation")] Invocation = 1,
|
||||
}
|
||||
|
||||
}
|
||||
#endregion
|
||||
|
||||
}
|
||||
|
||||
public sealed partial class RpcInvocationHeader : pb::IMessage<RpcInvocationHeader> {
|
||||
private static readonly pb::MessageParser<RpcInvocationHeader> _parser = new pb::MessageParser<RpcInvocationHeader>(() => new RpcInvocationHeader());
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public static pb::MessageParser<RpcInvocationHeader> Parser { get { return _parser; } }
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public static pbr::MessageDescriptor Descriptor {
|
||||
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[1]; }
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
pbr::MessageDescriptor pb::IMessage.Descriptor {
|
||||
get { return Descriptor; }
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public RpcInvocationHeader() {
|
||||
OnConstruction();
|
||||
}
|
||||
|
||||
partial void OnConstruction();
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public RpcInvocationHeader(RpcInvocationHeader other) : this() {
|
||||
name_ = other.name_;
|
||||
id_ = other.id_;
|
||||
numArgs_ = other.numArgs_;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public RpcInvocationHeader Clone() {
|
||||
return new RpcInvocationHeader(this);
|
||||
}
|
||||
|
||||
/// <summary>Field number for the "Name" field.</summary>
|
||||
public const int NameFieldNumber = 1;
|
||||
private string name_ = "";
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public string Name {
|
||||
get { return name_; }
|
||||
set {
|
||||
name_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>Field number for the "Id" field.</summary>
|
||||
public const int IdFieldNumber = 2;
|
||||
private int id_;
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public int Id {
|
||||
get { return id_; }
|
||||
set {
|
||||
id_ = value;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>Field number for the "NumArgs" field.</summary>
|
||||
public const int NumArgsFieldNumber = 3;
|
||||
private int numArgs_;
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public int NumArgs {
|
||||
get { return numArgs_; }
|
||||
set {
|
||||
numArgs_ = value;
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override bool Equals(object other) {
|
||||
return Equals(other as RpcInvocationHeader);
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public bool Equals(RpcInvocationHeader other) {
|
||||
if (ReferenceEquals(other, null)) {
|
||||
return false;
|
||||
}
|
||||
if (ReferenceEquals(other, this)) {
|
||||
return true;
|
||||
}
|
||||
if (Name != other.Name) return false;
|
||||
if (Id != other.Id) return false;
|
||||
if (NumArgs != other.NumArgs) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override int GetHashCode() {
|
||||
int hash = 1;
|
||||
if (Name.Length != 0) hash ^= Name.GetHashCode();
|
||||
if (Id != 0) hash ^= Id.GetHashCode();
|
||||
if (NumArgs != 0) hash ^= NumArgs.GetHashCode();
|
||||
return hash;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override string ToString() {
|
||||
return pb::JsonFormatter.ToDiagnosticString(this);
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void WriteTo(pb::CodedOutputStream output) {
|
||||
if (Name.Length != 0) {
|
||||
output.WriteRawTag(10);
|
||||
output.WriteString(Name);
|
||||
}
|
||||
if (Id != 0) {
|
||||
output.WriteRawTag(16);
|
||||
output.WriteInt32(Id);
|
||||
}
|
||||
if (NumArgs != 0) {
|
||||
output.WriteRawTag(24);
|
||||
output.WriteInt32(NumArgs);
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public int CalculateSize() {
|
||||
int size = 0;
|
||||
if (Name.Length != 0) {
|
||||
size += 1 + pb::CodedOutputStream.ComputeStringSize(Name);
|
||||
}
|
||||
if (Id != 0) {
|
||||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id);
|
||||
}
|
||||
if (NumArgs != 0) {
|
||||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumArgs);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void MergeFrom(RpcInvocationHeader other) {
|
||||
if (other == null) {
|
||||
return;
|
||||
}
|
||||
if (other.Name.Length != 0) {
|
||||
Name = other.Name;
|
||||
}
|
||||
if (other.Id != 0) {
|
||||
Id = other.Id;
|
||||
}
|
||||
if (other.NumArgs != 0) {
|
||||
NumArgs = other.NumArgs;
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void MergeFrom(pb::CodedInputStream input) {
|
||||
uint tag;
|
||||
while ((tag = input.ReadTag()) != 0) {
|
||||
switch(tag) {
|
||||
default:
|
||||
input.SkipLastField();
|
||||
break;
|
||||
case 10: {
|
||||
Name = input.ReadString();
|
||||
break;
|
||||
}
|
||||
case 16: {
|
||||
Id = input.ReadInt32();
|
||||
break;
|
||||
}
|
||||
case 24: {
|
||||
NumArgs = input.ReadInt32();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public sealed partial class RpcInvocationResultHeader : pb::IMessage<RpcInvocationResultHeader> {
|
||||
private static readonly pb::MessageParser<RpcInvocationResultHeader> _parser = new pb::MessageParser<RpcInvocationResultHeader>(() => new RpcInvocationResultHeader());
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public static pb::MessageParser<RpcInvocationResultHeader> Parser { get { return _parser; } }
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public static pbr::MessageDescriptor Descriptor {
|
||||
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[2]; }
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
pbr::MessageDescriptor pb::IMessage.Descriptor {
|
||||
get { return Descriptor; }
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public RpcInvocationResultHeader() {
|
||||
OnConstruction();
|
||||
}
|
||||
|
||||
partial void OnConstruction();
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public RpcInvocationResultHeader(RpcInvocationResultHeader other) : this() {
|
||||
id_ = other.id_;
|
||||
hasResult_ = other.hasResult_;
|
||||
error_ = other.error_;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public RpcInvocationResultHeader Clone() {
|
||||
return new RpcInvocationResultHeader(this);
|
||||
}
|
||||
|
||||
/// <summary>Field number for the "Id" field.</summary>
|
||||
public const int IdFieldNumber = 1;
|
||||
private int id_;
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public int Id {
|
||||
get { return id_; }
|
||||
set {
|
||||
id_ = value;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>Field number for the "HasResult" field.</summary>
|
||||
public const int HasResultFieldNumber = 2;
|
||||
private bool hasResult_;
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public bool HasResult {
|
||||
get { return hasResult_; }
|
||||
set {
|
||||
hasResult_ = value;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>Field number for the "Error" field.</summary>
|
||||
public const int ErrorFieldNumber = 3;
|
||||
private string error_ = "";
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public string Error {
|
||||
get { return error_; }
|
||||
set {
|
||||
error_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override bool Equals(object other) {
|
||||
return Equals(other as RpcInvocationResultHeader);
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public bool Equals(RpcInvocationResultHeader other) {
|
||||
if (ReferenceEquals(other, null)) {
|
||||
return false;
|
||||
}
|
||||
if (ReferenceEquals(other, this)) {
|
||||
return true;
|
||||
}
|
||||
if (Id != other.Id) return false;
|
||||
if (HasResult != other.HasResult) return false;
|
||||
if (Error != other.Error) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override int GetHashCode() {
|
||||
int hash = 1;
|
||||
if (Id != 0) hash ^= Id.GetHashCode();
|
||||
if (HasResult != false) hash ^= HasResult.GetHashCode();
|
||||
if (Error.Length != 0) hash ^= Error.GetHashCode();
|
||||
return hash;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override string ToString() {
|
||||
return pb::JsonFormatter.ToDiagnosticString(this);
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void WriteTo(pb::CodedOutputStream output) {
|
||||
if (Id != 0) {
|
||||
output.WriteRawTag(8);
|
||||
output.WriteInt32(Id);
|
||||
}
|
||||
if (HasResult != false) {
|
||||
output.WriteRawTag(16);
|
||||
output.WriteBool(HasResult);
|
||||
}
|
||||
if (Error.Length != 0) {
|
||||
output.WriteRawTag(26);
|
||||
output.WriteString(Error);
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public int CalculateSize() {
|
||||
int size = 0;
|
||||
if (Id != 0) {
|
||||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id);
|
||||
}
|
||||
if (HasResult != false) {
|
||||
size += 1 + 1;
|
||||
}
|
||||
if (Error.Length != 0) {
|
||||
size += 1 + pb::CodedOutputStream.ComputeStringSize(Error);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void MergeFrom(RpcInvocationResultHeader other) {
|
||||
if (other == null) {
|
||||
return;
|
||||
}
|
||||
if (other.Id != 0) {
|
||||
Id = other.Id;
|
||||
}
|
||||
if (other.HasResult != false) {
|
||||
HasResult = other.HasResult;
|
||||
}
|
||||
if (other.Error.Length != 0) {
|
||||
Error = other.Error;
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void MergeFrom(pb::CodedInputStream input) {
|
||||
uint tag;
|
||||
while ((tag = input.ReadTag()) != 0) {
|
||||
switch(tag) {
|
||||
default:
|
||||
input.SkipLastField();
|
||||
break;
|
||||
case 8: {
|
||||
Id = input.ReadInt32();
|
||||
break;
|
||||
}
|
||||
case 16: {
|
||||
HasResult = input.ReadBool();
|
||||
break;
|
||||
}
|
||||
case 26: {
|
||||
Error = input.ReadString();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public sealed partial class PrimitiveValue : pb::IMessage<PrimitiveValue> {
|
||||
private static readonly pb::MessageParser<PrimitiveValue> _parser = new pb::MessageParser<PrimitiveValue>(() => new PrimitiveValue());
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public static pb::MessageParser<PrimitiveValue> Parser { get { return _parser; } }
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public static pbr::MessageDescriptor Descriptor {
|
||||
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[3]; }
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
pbr::MessageDescriptor pb::IMessage.Descriptor {
|
||||
get { return Descriptor; }
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public PrimitiveValue() {
|
||||
OnConstruction();
|
||||
}
|
||||
|
||||
partial void OnConstruction();
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public PrimitiveValue(PrimitiveValue other) : this() {
|
||||
switch (other.OneofCase) {
|
||||
case OneofOneofCase.Int32Value:
|
||||
Int32Value = other.Int32Value;
|
||||
break;
|
||||
case OneofOneofCase.StringValue:
|
||||
StringValue = other.StringValue;
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public PrimitiveValue Clone() {
|
||||
return new PrimitiveValue(this);
|
||||
}
|
||||
|
||||
/// <summary>Field number for the "Int32Value" field.</summary>
|
||||
public const int Int32ValueFieldNumber = 1;
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public int Int32Value {
|
||||
get { return oneofCase_ == OneofOneofCase.Int32Value ? (int) oneof_ : 0; }
|
||||
set {
|
||||
oneof_ = value;
|
||||
oneofCase_ = OneofOneofCase.Int32Value;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>Field number for the "StringValue" field.</summary>
|
||||
public const int StringValueFieldNumber = 2;
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public string StringValue {
|
||||
get { return oneofCase_ == OneofOneofCase.StringValue ? (string) oneof_ : ""; }
|
||||
set {
|
||||
oneof_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
|
||||
oneofCase_ = OneofOneofCase.StringValue;
|
||||
}
|
||||
}
|
||||
|
||||
private object oneof_;
|
||||
/// <summary>Enum of possible cases for the "oneof_" oneof.</summary>
|
||||
public enum OneofOneofCase {
|
||||
None = 0,
|
||||
Int32Value = 1,
|
||||
StringValue = 2,
|
||||
}
|
||||
private OneofOneofCase oneofCase_ = OneofOneofCase.None;
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public OneofOneofCase OneofCase {
|
||||
get { return oneofCase_; }
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void ClearOneof() {
|
||||
oneofCase_ = OneofOneofCase.None;
|
||||
oneof_ = null;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override bool Equals(object other) {
|
||||
return Equals(other as PrimitiveValue);
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public bool Equals(PrimitiveValue other) {
|
||||
if (ReferenceEquals(other, null)) {
|
||||
return false;
|
||||
}
|
||||
if (ReferenceEquals(other, this)) {
|
||||
return true;
|
||||
}
|
||||
if (Int32Value != other.Int32Value) return false;
|
||||
if (StringValue != other.StringValue) return false;
|
||||
if (OneofCase != other.OneofCase) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override int GetHashCode() {
|
||||
int hash = 1;
|
||||
if (oneofCase_ == OneofOneofCase.Int32Value) hash ^= Int32Value.GetHashCode();
|
||||
if (oneofCase_ == OneofOneofCase.StringValue) hash ^= StringValue.GetHashCode();
|
||||
hash ^= (int) oneofCase_;
|
||||
return hash;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override string ToString() {
|
||||
return pb::JsonFormatter.ToDiagnosticString(this);
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void WriteTo(pb::CodedOutputStream output) {
|
||||
if (oneofCase_ == OneofOneofCase.Int32Value) {
|
||||
output.WriteRawTag(8);
|
||||
output.WriteInt32(Int32Value);
|
||||
}
|
||||
if (oneofCase_ == OneofOneofCase.StringValue) {
|
||||
output.WriteRawTag(18);
|
||||
output.WriteString(StringValue);
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public int CalculateSize() {
|
||||
int size = 0;
|
||||
if (oneofCase_ == OneofOneofCase.Int32Value) {
|
||||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Int32Value);
|
||||
}
|
||||
if (oneofCase_ == OneofOneofCase.StringValue) {
|
||||
size += 1 + pb::CodedOutputStream.ComputeStringSize(StringValue);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void MergeFrom(PrimitiveValue other) {
|
||||
if (other == null) {
|
||||
return;
|
||||
}
|
||||
switch (other.OneofCase) {
|
||||
case OneofOneofCase.Int32Value:
|
||||
Int32Value = other.Int32Value;
|
||||
break;
|
||||
case OneofOneofCase.StringValue:
|
||||
StringValue = other.StringValue;
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void MergeFrom(pb::CodedInputStream input) {
|
||||
uint tag;
|
||||
while ((tag = input.ReadTag()) != 0) {
|
||||
switch(tag) {
|
||||
default:
|
||||
input.SkipLastField();
|
||||
break;
|
||||
case 8: {
|
||||
Int32Value = input.ReadInt32();
|
||||
break;
|
||||
}
|
||||
case 18: {
|
||||
StringValue = input.ReadString();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public sealed partial class PersonMessage : pb::IMessage<PersonMessage> {
|
||||
private static readonly pb::MessageParser<PersonMessage> _parser = new pb::MessageParser<PersonMessage>(() => new PersonMessage());
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public static pb::MessageParser<PersonMessage> Parser { get { return _parser; } }
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public static pbr::MessageDescriptor Descriptor {
|
||||
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[4]; }
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
pbr::MessageDescriptor pb::IMessage.Descriptor {
|
||||
get { return Descriptor; }
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public PersonMessage() {
|
||||
OnConstruction();
|
||||
}
|
||||
|
||||
partial void OnConstruction();
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public PersonMessage(PersonMessage other) : this() {
|
||||
name_ = other.name_;
|
||||
age_ = other.age_;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public PersonMessage Clone() {
|
||||
return new PersonMessage(this);
|
||||
}
|
||||
|
||||
/// <summary>Field number for the "Name" field.</summary>
|
||||
public const int NameFieldNumber = 1;
|
||||
private string name_ = "";
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public string Name {
|
||||
get { return name_; }
|
||||
set {
|
||||
name_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>Field number for the "Age" field.</summary>
|
||||
public const int AgeFieldNumber = 2;
|
||||
private long age_;
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public long Age {
|
||||
get { return age_; }
|
||||
set {
|
||||
age_ = value;
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override bool Equals(object other) {
|
||||
return Equals(other as PersonMessage);
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public bool Equals(PersonMessage other) {
|
||||
if (ReferenceEquals(other, null)) {
|
||||
return false;
|
||||
}
|
||||
if (ReferenceEquals(other, this)) {
|
||||
return true;
|
||||
}
|
||||
if (Name != other.Name) return false;
|
||||
if (Age != other.Age) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override int GetHashCode() {
|
||||
int hash = 1;
|
||||
if (Name.Length != 0) hash ^= Name.GetHashCode();
|
||||
if (Age != 0L) hash ^= Age.GetHashCode();
|
||||
return hash;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public override string ToString() {
|
||||
return pb::JsonFormatter.ToDiagnosticString(this);
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void WriteTo(pb::CodedOutputStream output) {
|
||||
if (Name.Length != 0) {
|
||||
output.WriteRawTag(10);
|
||||
output.WriteString(Name);
|
||||
}
|
||||
if (Age != 0L) {
|
||||
output.WriteRawTag(16);
|
||||
output.WriteInt64(Age);
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public int CalculateSize() {
|
||||
int size = 0;
|
||||
if (Name.Length != 0) {
|
||||
size += 1 + pb::CodedOutputStream.ComputeStringSize(Name);
|
||||
}
|
||||
if (Age != 0L) {
|
||||
size += 1 + pb::CodedOutputStream.ComputeInt64Size(Age);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void MergeFrom(PersonMessage other) {
|
||||
if (other == null) {
|
||||
return;
|
||||
}
|
||||
if (other.Name.Length != 0) {
|
||||
Name = other.Name;
|
||||
}
|
||||
if (other.Age != 0L) {
|
||||
Age = other.Age;
|
||||
}
|
||||
}
|
||||
|
||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
|
||||
public void MergeFrom(pb::CodedInputStream input) {
|
||||
uint tag;
|
||||
while ((tag = input.ReadTag()) != 0) {
|
||||
switch(tag) {
|
||||
default:
|
||||
input.SkipLastField();
|
||||
break;
|
||||
case 10: {
|
||||
Name = input.ReadString();
|
||||
break;
|
||||
}
|
||||
case 16: {
|
||||
Age = input.ReadInt64();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
|
||||
#endregion Designer generated code
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
syntax = "proto3";
|
||||
|
||||
message RpcMessageKind {
|
||||
enum Kind { Result = 0; Invocation = 1; }
|
||||
Kind MessageKind = 1;
|
||||
}
|
||||
|
||||
message RpcInvocationHeader {
|
||||
string Name = 1;
|
||||
int32 Id = 2;
|
||||
int32 NumArgs = 3;
|
||||
}
|
||||
|
||||
message RpcInvocationResultHeader {
|
||||
int32 Id = 1;
|
||||
bool HasResult = 2;
|
||||
string Error = 3;
|
||||
}
|
||||
|
||||
message PrimitiveValue {
|
||||
oneof oneof_ {
|
||||
int32 Int32Value = 1;
|
||||
string StringValue = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message PersonMessage {
|
||||
string Name = 1;
|
||||
int64 Age = 2;
|
||||
}
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using Google.Protobuf;
|
||||
using SocketsSample.Hubs;
|
||||
|
||||
namespace SocketsSample
|
||||
{
|
||||
public class ProtobufSerializer
|
||||
{
|
||||
public object GetValue(CodedInputStream inputStream, Type type)
|
||||
{
|
||||
if (type == typeof(Person))
|
||||
{
|
||||
var value = new PersonMessage();
|
||||
inputStream.ReadMessage(value);
|
||||
|
||||
return new Person { Name = value.Name, Age = value.Age };
|
||||
}
|
||||
|
||||
throw new InvalidOperationException("(Deserialize) Unknown type.");
|
||||
}
|
||||
|
||||
public IMessage GetMessage(object value)
|
||||
{
|
||||
Person person = value as Person;
|
||||
if (person != null)
|
||||
{
|
||||
return new PersonMessage { Name = person.Name, Age = person.Age };
|
||||
}
|
||||
|
||||
throw new InvalidOperationException("(Serialize) Unknown type.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -3,12 +3,10 @@
|
|||
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AspNetCore.Hosting;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using SocketsSample.EndPoints;
|
||||
using SocketsSample.Hubs;
|
||||
using SocketsSample.Protobuf;
|
||||
|
||||
namespace SocketsSample
|
||||
{
|
||||
|
|
@ -18,21 +16,12 @@ namespace SocketsSample
|
|||
// For more information on how to configure your application, visit http://go.microsoft.com/fwlink/?LinkID=398940
|
||||
public void ConfigureServices(IServiceCollection services)
|
||||
{
|
||||
services.AddSingleton<ProtobufInvocationAdapter>();
|
||||
services.AddSingleton<LineInvocationAdapter>();
|
||||
|
||||
services.AddSockets();
|
||||
|
||||
services.AddSignalR(options =>
|
||||
{
|
||||
options.RegisterInvocationAdapter<ProtobufInvocationAdapter>("protobuf");
|
||||
options.RegisterInvocationAdapter<LineInvocationAdapter>("line");
|
||||
});
|
||||
services.AddSignalR();
|
||||
// .AddRedis();
|
||||
|
||||
services.AddEndPoint<MessagesEndPoint>();
|
||||
|
||||
services.AddSingleton<ProtobufSerializer>();
|
||||
}
|
||||
|
||||
// This method gets called by the runtime. Use this method to configure the HTTP request pipeline.
|
||||
|
|
|
|||
|
|
@ -1,19 +1,20 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Net.Http;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.AspNetCore.Sockets.Client;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Newtonsoft.Json;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Client
|
||||
{
|
||||
|
|
@ -22,17 +23,14 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
private readonly ILoggerFactory _loggerFactory;
|
||||
private readonly ILogger _logger;
|
||||
private readonly IConnection _connection;
|
||||
private readonly IInvocationAdapter _adapter;
|
||||
private readonly IHubProtocol _protocol;
|
||||
private readonly HubBinder _binder;
|
||||
|
||||
private HttpClient _httpClient;
|
||||
|
||||
private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();
|
||||
|
||||
// We need to ensure pending calls added after a connection failure don't hang. Right now the easiest thing to do is lock.
|
||||
private readonly object _pendingCallsLock = new object();
|
||||
private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();
|
||||
private readonly Dictionary<string, InvocationRequest> _pendingCalls = new Dictionary<string, InvocationRequest>();
|
||||
|
||||
private readonly ConcurrentDictionary<string, InvocationHandler> _handlers = new ConcurrentDictionary<string, InvocationHandler>();
|
||||
|
||||
private int _nextId = 0;
|
||||
|
|
@ -50,27 +48,33 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
}
|
||||
|
||||
public HubConnection(Uri url)
|
||||
: this(new Connection(url), new JsonNetInvocationAdapter(), null)
|
||||
: this(new Connection(url), new JsonHubProtocol(new JsonSerializer()), null)
|
||||
{ }
|
||||
|
||||
public HubConnection(Uri url, ILoggerFactory loggerFactory)
|
||||
: this(new Connection(url), new JsonNetInvocationAdapter(), loggerFactory)
|
||||
: this(new Connection(url), new JsonHubProtocol(new JsonSerializer()), loggerFactory)
|
||||
{ }
|
||||
|
||||
public HubConnection(Uri url, IInvocationAdapter adapter, ILoggerFactory loggerFactory)
|
||||
: this(new Connection(url, loggerFactory), adapter, loggerFactory)
|
||||
// These are only really needed for tests now...
|
||||
public HubConnection(IConnection connection, ILoggerFactory loggerFactory)
|
||||
: this(connection, new JsonHubProtocol(new JsonSerializer()), loggerFactory)
|
||||
{ }
|
||||
|
||||
public HubConnection(IConnection connection, IInvocationAdapter adapter, ILoggerFactory loggerFactory)
|
||||
public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFactory loggerFactory)
|
||||
{
|
||||
if (connection == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(connection));
|
||||
}
|
||||
|
||||
if (protocol == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(protocol));
|
||||
}
|
||||
|
||||
_connection = connection;
|
||||
_binder = new HubBinder(this);
|
||||
_adapter = adapter;
|
||||
_protocol = protocol;
|
||||
_loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
|
||||
_logger = _loggerFactory.CreateLogger<HubConnection>();
|
||||
_connection.Received += OnDataReceived;
|
||||
|
|
@ -100,6 +104,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
public async Task DisposeAsync()
|
||||
{
|
||||
await _connection.DisposeAsync();
|
||||
|
||||
_httpClient?.Dispose();
|
||||
}
|
||||
|
||||
|
|
@ -118,81 +123,80 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
public Task<object> Invoke(string methodName, Type returnType, params object[] args) => Invoke(methodName, returnType, CancellationToken.None, args);
|
||||
public async Task<object> Invoke(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args)
|
||||
{
|
||||
_logger.LogTrace("Preparing invocation of '{0}', with return type '{1}' and {2} args", methodName, returnType.AssemblyQualifiedName, args.Length);
|
||||
ThrowIfConnectionTerminated();
|
||||
_logger.LogTrace("Preparing invocation of '{target}', with return type '{returnType}' and {argumentCount} args", methodName, returnType.AssemblyQualifiedName, args.Length);
|
||||
|
||||
// Create an invocation descriptor.
|
||||
var descriptor = new InvocationDescriptor
|
||||
{
|
||||
Id = GetNextId(),
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
// Create an invocation descriptor. Client invocations are always blocking
|
||||
var invocationMessage = new InvocationMessage(GetNextId(), nonBlocking: false, target: methodName, arguments: args);
|
||||
|
||||
// I just want an excuse to use 'irq' as a variable name...
|
||||
_logger.LogDebug("Registering Invocation ID '{0}' for tracking", descriptor.Id);
|
||||
var irq = new InvocationRequest(cancellationToken, returnType);
|
||||
_logger.LogDebug("Registering Invocation ID '{invocationId}' for tracking", invocationMessage.InvocationId);
|
||||
var irq = new InvocationRequest(cancellationToken, returnType, invocationMessage.InvocationId, _loggerFactory);
|
||||
|
||||
lock (_pendingCallsLock)
|
||||
{
|
||||
if (_connectionActive.IsCancellationRequested)
|
||||
{
|
||||
throw new InvalidOperationException("Connection has been terminated.");
|
||||
}
|
||||
_pendingCalls.Add(descriptor.Id, irq);
|
||||
}
|
||||
AddInvocation(irq);
|
||||
|
||||
// Trace the invocation, but only if that logging level is enabled (because building the args list is a bit slow)
|
||||
// Trace the full invocation, but only if that logging level is enabled (because building the args list is a bit slow)
|
||||
if (_logger.IsEnabled(LogLevel.Trace))
|
||||
{
|
||||
var argsList = string.Join(", ", args.Select(a => a.GetType().FullName));
|
||||
_logger.LogTrace("Invocation #{0}: {1} {2}({3})", descriptor.Id, returnType.FullName, methodName, argsList);
|
||||
_logger.LogTrace("Issuing Invocation '{invocationId}': {returnType} {methodName}({args})", invocationMessage.InvocationId, returnType.FullName, methodName, argsList);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var ms = new MemoryStream();
|
||||
await _adapter.WriteMessageAsync(descriptor, ms, cancellationToken);
|
||||
var payload = await _protocol.WriteToArrayAsync(invocationMessage);
|
||||
|
||||
_logger.LogInformation("Sending Invocation #{0}", descriptor.Id);
|
||||
_logger.LogInformation("Sending Invocation '{invocationId}'", invocationMessage.InvocationId);
|
||||
|
||||
// TODO: Format.Text - who, where and when decides about the format of outgoing messages
|
||||
await _connection.SendAsync(ms.ToArray(), MessageType.Text, cancellationToken);
|
||||
_logger.LogInformation("Sending Invocation #{0} complete", descriptor.Id);
|
||||
await _connection.SendAsync(payload, _protocol.MessageType, cancellationToken);
|
||||
_logger.LogInformation("Sending Invocation '{invocationId}' complete", invocationMessage.InvocationId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(0, ex, "Sending Invocation #{0} failed", descriptor.Id);
|
||||
irq.Completion.TrySetException(ex);
|
||||
lock (_pendingCallsLock)
|
||||
{
|
||||
_pendingCalls.Remove(descriptor.Id);
|
||||
}
|
||||
_logger.LogError(0, ex, "Sending Invocation '{invocationId}' failed", invocationMessage.InvocationId);
|
||||
irq.Fail(ex);
|
||||
TryRemoveInvocation(invocationMessage.InvocationId, out _);
|
||||
}
|
||||
|
||||
// Return the completion task. It will be completed by ReceiveMessages when the response is received.
|
||||
return await irq.Completion.Task;
|
||||
return await irq.Completion;
|
||||
}
|
||||
|
||||
private async void OnDataReceived(byte[] data, MessageType messageType)
|
||||
private void OnDataReceived(byte[] data, MessageType messageType)
|
||||
{
|
||||
var message
|
||||
= await _adapter.ReadMessageAsync(new MemoryStream(data), _binder, _connectionActive.Token);
|
||||
var message = _protocol.ParseMessage(data, _binder);
|
||||
|
||||
InvocationRequest irq;
|
||||
switch (message)
|
||||
{
|
||||
case InvocationDescriptor invocationDescriptor:
|
||||
DispatchInvocation(invocationDescriptor, _connectionActive.Token);
|
||||
break;
|
||||
case InvocationResultDescriptor invocationResultDescriptor:
|
||||
InvocationRequest irq;
|
||||
lock (_pendingCallsLock)
|
||||
case InvocationMessage invocation:
|
||||
if (_logger.IsEnabled(LogLevel.Trace))
|
||||
{
|
||||
_connectionActive.Token.ThrowIfCancellationRequested();
|
||||
irq = _pendingCalls[invocationResultDescriptor.Id];
|
||||
_pendingCalls.Remove(invocationResultDescriptor.Id);
|
||||
var argsList = string.Join(", ", invocation.Arguments.Select(a => a.GetType().FullName));
|
||||
_logger.LogTrace("Received Invocation '{invocationId}': {methodName}({args})", invocation.InvocationId, invocation.Target, argsList);
|
||||
}
|
||||
DispatchInvocationResult(invocationResultDescriptor, irq, _connectionActive.Token);
|
||||
DispatchInvocation(invocation, _connectionActive.Token);
|
||||
break;
|
||||
case CompletionMessage completion:
|
||||
if (!TryRemoveInvocation(completion.InvocationId, out irq))
|
||||
{
|
||||
_logger.LogWarning("Dropped unsolicited Completion message for invocation '{invocationId}'", completion.InvocationId);
|
||||
return;
|
||||
}
|
||||
DispatchInvocationCompletion(completion, irq);
|
||||
irq.Dispose();
|
||||
break;
|
||||
case StreamItemMessage streamItem:
|
||||
// Complete the invocation with an error, we don't support streaming (yet)
|
||||
if (!TryRemoveInvocation(streamItem.InvocationId, out irq))
|
||||
{
|
||||
_logger.LogWarning("Dropped unsolicited Stream Item message for invocation '{invocationId}'", streamItem.InvocationId);
|
||||
return;
|
||||
}
|
||||
irq.Fail(new NotSupportedException("Streaming method results are not supported"));
|
||||
break;
|
||||
default:
|
||||
throw new InvalidOperationException($"Unknown message type: {message.GetType().FullName}");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -201,74 +205,118 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
_logger.LogTrace("Shutting down connection");
|
||||
if (ex != null)
|
||||
{
|
||||
_logger.LogError("Connection is shutting down due to an error: {0}", ex);
|
||||
_logger.LogError(ex, "Connection is shutting down due to an error");
|
||||
}
|
||||
|
||||
lock (_pendingCallsLock)
|
||||
{
|
||||
// We cancel inside the lock to make sure everyone who was part-way through registering an invocation
|
||||
// completes. This also ensures that nobody will add things to _pendingCalls after we leave this block
|
||||
// because everything that adds to _pendingCalls checks _connectionActive first (inside the _pendingCallsLock)
|
||||
_connectionActive.Cancel();
|
||||
foreach (var call in _pendingCalls.Values)
|
||||
|
||||
foreach (var outstandingCall in _pendingCalls.Values)
|
||||
{
|
||||
_logger.LogTrace("Removing pending call {invocationId}", outstandingCall.InvocationId);
|
||||
if (ex != null)
|
||||
{
|
||||
call.Completion.TrySetException(ex);
|
||||
}
|
||||
else
|
||||
{
|
||||
call.Completion.TrySetCanceled();
|
||||
outstandingCall.Fail(ex);
|
||||
}
|
||||
outstandingCall.Dispose();
|
||||
}
|
||||
_pendingCalls.Clear();
|
||||
}
|
||||
}
|
||||
|
||||
private void DispatchInvocation(InvocationDescriptor invocationDescriptor, CancellationToken cancellationToken)
|
||||
private void DispatchInvocation(InvocationMessage invocation, CancellationToken cancellationToken)
|
||||
{
|
||||
// Find the handler
|
||||
if (!_handlers.TryGetValue(invocationDescriptor.Method, out InvocationHandler handler))
|
||||
if (!_handlers.TryGetValue(invocation.Target, out InvocationHandler handler))
|
||||
{
|
||||
_logger.LogWarning("Failed to find handler for '{0}' method", invocationDescriptor.Method);
|
||||
_logger.LogWarning("Failed to find handler for '{target}' method", invocation.Target);
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: Return values
|
||||
// TODO: Dispatch to a sync context to ensure we aren't blocking this loop.
|
||||
handler.Handler(invocationDescriptor.Arguments);
|
||||
handler.Handler(invocation.Arguments);
|
||||
}
|
||||
|
||||
private void DispatchInvocationResult(InvocationResultDescriptor result, InvocationRequest irq, CancellationToken cancellationToken)
|
||||
private void DispatchInvocationCompletion(CompletionMessage completion, InvocationRequest irq)
|
||||
{
|
||||
_logger.LogInformation("Received Result for Invocation #{0}", result.Id);
|
||||
_logger.LogTrace("Received Completion for Invocation #{invocationId}", completion.InvocationId);
|
||||
|
||||
if (cancellationToken.IsCancellationRequested)
|
||||
if (irq.CancellationToken.IsCancellationRequested)
|
||||
{
|
||||
return;
|
||||
_logger.LogTrace("Cancelling dispatch of Completion message for Invocation {invocationId}. The invocation was cancelled.", irq.InvocationId);
|
||||
}
|
||||
|
||||
Debug.Assert(irq.Completion != null, "Didn't properly capture InvocationRequest in callback for ReadInvocationResultDescriptorAsync");
|
||||
|
||||
// If the invocation hasn't been cancelled, dispatch the result
|
||||
if (!irq.CancellationToken.IsCancellationRequested)
|
||||
else
|
||||
{
|
||||
irq.Registration.Dispose();
|
||||
|
||||
// Complete the request based on the result
|
||||
// TODO: the TrySetXYZ methods will cause continuations attached to the Task to run, so we should dispatch to a sync context or thread pool.
|
||||
if (!string.IsNullOrEmpty(result.Error))
|
||||
if (!string.IsNullOrEmpty(completion.Error))
|
||||
{
|
||||
_logger.LogInformation("Completing Invocation #{0} with error: {1}", result.Id, result.Error);
|
||||
irq.Completion.TrySetException(new Exception(result.Error));
|
||||
irq.Fail(new HubException(completion.Error));
|
||||
}
|
||||
else
|
||||
{
|
||||
_logger.LogInformation("Completing Invocation #{0} with result of type: {1}", result.Id, result.Result?.GetType()?.FullName ?? "<<void>>");
|
||||
irq.Completion.TrySetResult(result.Result);
|
||||
irq.Complete(completion.Result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void ThrowIfConnectionTerminated()
|
||||
{
|
||||
if (_connectionActive.Token.IsCancellationRequested)
|
||||
{
|
||||
_logger.LogError("Invoke was called after the connection was terminated");
|
||||
throw new InvalidOperationException("Connection has been terminated.");
|
||||
}
|
||||
}
|
||||
|
||||
private string GetNextId() => Interlocked.Increment(ref _nextId).ToString();
|
||||
|
||||
private void AddInvocation(InvocationRequest irq)
|
||||
{
|
||||
lock (_pendingCallsLock)
|
||||
{
|
||||
ThrowIfConnectionTerminated();
|
||||
if (_pendingCalls.ContainsKey(irq.InvocationId))
|
||||
{
|
||||
_logger.LogCritical("Invocation ID '{invocationId}' is already in use.", irq.InvocationId);
|
||||
throw new InvalidOperationException($"Invocation ID '{irq.InvocationId}' is already in use.");
|
||||
}
|
||||
else
|
||||
{
|
||||
_pendingCalls.Add(irq.InvocationId, irq);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private bool TryGetInvocation(string invocationId, out InvocationRequest irq)
|
||||
{
|
||||
lock (_pendingCallsLock)
|
||||
{
|
||||
ThrowIfConnectionTerminated();
|
||||
return _pendingCalls.TryGetValue(invocationId, out irq);
|
||||
}
|
||||
}
|
||||
|
||||
private bool TryRemoveInvocation(string invocationId, out InvocationRequest irq)
|
||||
{
|
||||
lock (_pendingCallsLock)
|
||||
{
|
||||
ThrowIfConnectionTerminated();
|
||||
if (_pendingCalls.TryGetValue(invocationId, out irq))
|
||||
{
|
||||
_pendingCalls.Remove(invocationId);
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class HubBinder : IInvocationBinder
|
||||
{
|
||||
private HubConnection _connection;
|
||||
|
|
@ -282,7 +330,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
{
|
||||
if (!_connection._pendingCalls.TryGetValue(invocationId, out InvocationRequest irq))
|
||||
{
|
||||
_connection._logger.LogError("Unsolicited response received for invocation '{0}'", invocationId);
|
||||
_connection._logger.LogError("Unsolicited response received for invocation '{invocationId}'", invocationId);
|
||||
return null;
|
||||
}
|
||||
return irq.ResultType;
|
||||
|
|
@ -292,7 +340,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
{
|
||||
if (!_connection._handlers.TryGetValue(methodName, out InvocationHandler handler))
|
||||
{
|
||||
_connection._logger.LogWarning("Failed to find handler for '{0}' method", methodName);
|
||||
_connection._logger.LogWarning("Failed to find handler for '{target}' method", methodName);
|
||||
return Type.EmptyTypes;
|
||||
}
|
||||
return handler.ParameterTypes;
|
||||
|
|
@ -311,20 +359,51 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
}
|
||||
}
|
||||
|
||||
private struct InvocationRequest
|
||||
private class InvocationRequest : IDisposable
|
||||
{
|
||||
private readonly TaskCompletionSource<object> _completionSource = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
private readonly CancellationTokenRegistration _cancellationTokenRegistration;
|
||||
private readonly ILogger _logger;
|
||||
|
||||
public Type ResultType { get; }
|
||||
public CancellationToken CancellationToken { get; }
|
||||
public CancellationTokenRegistration Registration { get; }
|
||||
public TaskCompletionSource<object> Completion { get; }
|
||||
public string InvocationId { get; }
|
||||
|
||||
public InvocationRequest(CancellationToken cancellationToken, Type resultType)
|
||||
public Task<object> Completion => _completionSource.Task;
|
||||
|
||||
|
||||
public InvocationRequest(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory)
|
||||
{
|
||||
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
Completion = tcs;
|
||||
_logger = loggerFactory.CreateLogger<InvocationRequest>();
|
||||
_cancellationTokenRegistration = cancellationToken.Register(() => _completionSource.TrySetCanceled());
|
||||
|
||||
InvocationId = invocationId;
|
||||
CancellationToken = cancellationToken;
|
||||
Registration = cancellationToken.Register(() => tcs.TrySetCanceled());
|
||||
ResultType = resultType;
|
||||
|
||||
_logger.LogTrace("Invocation {invocationId} created", InvocationId);
|
||||
}
|
||||
|
||||
public void Fail(Exception exception)
|
||||
{
|
||||
_logger.LogTrace("Invocation {invocationId} marked as failed", InvocationId);
|
||||
_completionSource.TrySetException(exception);
|
||||
}
|
||||
|
||||
public void Complete(object result)
|
||||
{
|
||||
_logger.LogTrace("Invocation {invocationId} marked as completed", InvocationId);
|
||||
_completionSource.TrySetResult(result);
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
_logger.LogTrace("Invocation {invocationId} disposed", InvocationId);
|
||||
|
||||
// Just in case it hasn't already been completed
|
||||
_completionSource.TrySetCanceled();
|
||||
|
||||
_cancellationTokenRegistration.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Client
|
||||
{
|
||||
[Serializable]
|
||||
public class HubException : Exception
|
||||
{
|
||||
public HubException()
|
||||
{
|
||||
}
|
||||
|
||||
public HubException(string message) : base(message)
|
||||
{
|
||||
}
|
||||
|
||||
public HubException(string message, Exception innerException) : base(message, innerException)
|
||||
{
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.IO;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
public interface IInvocationAdapter
|
||||
{
|
||||
Task<InvocationMessage> ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken);
|
||||
|
||||
Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken);
|
||||
}
|
||||
}
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
using System;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||
{
|
||||
public interface IInvocationBinder
|
||||
{
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
||||
{
|
||||
public class CompletionMessage : HubMessage
|
||||
{
|
||||
public string Error { get; }
|
||||
public object Result { get; }
|
||||
public bool HasResult { get; }
|
||||
|
||||
public CompletionMessage(string invocationId, string error, object result, bool hasResult) : base(invocationId)
|
||||
{
|
||||
if (error != null && result != null)
|
||||
{
|
||||
throw new ArgumentException($"Expected either '{nameof(error)}' or '{nameof(result)}' to be provided, but not both");
|
||||
}
|
||||
Error = error;
|
||||
Result = result;
|
||||
HasResult = hasResult;
|
||||
}
|
||||
|
||||
public override string ToString()
|
||||
{
|
||||
var errorStr = Error == null ? "<<null>>" : $"\"{Error}\"";
|
||||
var resultField = HasResult ? $", {nameof(Result)}: {Result ?? "<<null>>"}" : string.Empty;
|
||||
return $"Completion {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Error)}: {errorStr}{resultField} }}";
|
||||
}
|
||||
|
||||
// Static factory methods. Don't want to use constructor overloading because it will break down
|
||||
// if you need to send a payload statically-typed as a string. And because a static factory is clearer here
|
||||
public static CompletionMessage WithError(string invocationId, string error) => new CompletionMessage(invocationId, error, result: null, hasResult: false);
|
||||
|
||||
public static CompletionMessage WithResult(string invocationId, object payload) => new CompletionMessage(invocationId, error: null, result: payload, hasResult: true);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
||||
{
|
||||
public abstract class HubMessage
|
||||
{
|
||||
public string InvocationId { get; }
|
||||
|
||||
protected HubMessage(string invocationId)
|
||||
{
|
||||
InvocationId = invocationId;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
using System.IO.Pipelines.Text.Primitives;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
||||
{
|
||||
public static class HubProtocolWriteMessageExtensions
|
||||
{
|
||||
public static async ValueTask<byte[]> WriteToArrayAsync(this IHubProtocol protocol, HubMessage message)
|
||||
{
|
||||
using (var memoryStream = new MemoryStream())
|
||||
{
|
||||
var pipe = memoryStream.AsPipelineWriter();
|
||||
|
||||
// See https://github.com/dotnet/corefxlab/issues/1460, the TextEncoder is unimportant but required.
|
||||
var output = new PipelineTextOutput(pipe, TextEncoder.Utf8);
|
||||
|
||||
// Encode the message
|
||||
if (!protocol.TryWriteMessage(message, output))
|
||||
{
|
||||
throw new InvalidOperationException("Failed to write message to the output stream");
|
||||
}
|
||||
|
||||
await output.FlushAsync();
|
||||
|
||||
// Create a message
|
||||
return memoryStream.ToArray();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
||||
{
|
||||
public interface IHubProtocol
|
||||
{
|
||||
MessageType MessageType { get; }
|
||||
|
||||
HubMessage ParseMessage(ReadOnlySpan<byte> input, IInvocationBinder binder);
|
||||
|
||||
bool TryWriteMessage(HubMessage message, IOutput output);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Linq;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
||||
{
|
||||
public class InvocationMessage : HubMessage
|
||||
{
|
||||
public string Target { get; }
|
||||
|
||||
public object[] Arguments { get; }
|
||||
|
||||
public bool NonBlocking { get; }
|
||||
|
||||
public InvocationMessage(string invocationId, bool nonBlocking, string target, params object[] arguments) : base(invocationId)
|
||||
{
|
||||
if (string.IsNullOrEmpty(invocationId))
|
||||
{
|
||||
throw new ArgumentNullException(nameof(invocationId));
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(target))
|
||||
{
|
||||
throw new ArgumentNullException(nameof(target));
|
||||
}
|
||||
|
||||
if (arguments == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(arguments));
|
||||
}
|
||||
|
||||
Target = target;
|
||||
Arguments = arguments;
|
||||
NonBlocking = nonBlocking;
|
||||
}
|
||||
|
||||
public override string ToString()
|
||||
{
|
||||
return $"Invocation {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(NonBlocking)}: {NonBlocking}, {nameof(Target)}: \"{Target}\", {nameof(Arguments)}: [ {string.Join(", ", Arguments.Select(a => a?.ToString()))} ] }}";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,281 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.IO;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Newtonsoft.Json;
|
||||
using Newtonsoft.Json.Linq;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
||||
{
|
||||
public class JsonHubProtocol : IHubProtocol
|
||||
{
|
||||
private const string ResultPropertyName = "result";
|
||||
private const string InvocationIdPropertyName = "invocationId";
|
||||
private const string TypePropertyName = "type";
|
||||
private const string ErrorPropertyName = "error";
|
||||
private const string TargetPropertyName = "target";
|
||||
private const string NonBlockingPropertyName = "nonBlocking";
|
||||
private const string ArgumentsPropertyName = "arguments";
|
||||
|
||||
private const int InvocationMessageType = 1;
|
||||
private const int ResultMessageType = 2;
|
||||
private const int CompletionMessageType = 3;
|
||||
|
||||
// ONLY to be used for application payloads (args, return values, etc.)
|
||||
private JsonSerializer _payloadSerializer;
|
||||
|
||||
public MessageType MessageType => MessageType.Text;
|
||||
|
||||
/// <summary>
|
||||
/// Creates an instance of the <see cref="JsonHubProtocol"/> using the specified <see cref="JsonSerializer"/>
|
||||
/// to serialize application payloads (arguments, results, etc.). The serialization of the outer protocol can
|
||||
/// NOT be changed using this serializer.
|
||||
/// </summary>
|
||||
/// <param name="payloadSerializer">The <see cref="JsonSerializer"/> to use to serialize application payloads (arguments, results, etc.).</param>
|
||||
public JsonHubProtocol(JsonSerializer payloadSerializer)
|
||||
{
|
||||
if (payloadSerializer == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(payloadSerializer));
|
||||
}
|
||||
|
||||
_payloadSerializer = payloadSerializer;
|
||||
}
|
||||
|
||||
public HubMessage ParseMessage(ReadOnlySpan<byte> input, IInvocationBinder binder)
|
||||
{
|
||||
// TODO: Need a span-native JSON parser!
|
||||
using (var memoryStream = new MemoryStream(input.ToArray()))
|
||||
{
|
||||
return ParseMessage(memoryStream, binder);
|
||||
}
|
||||
}
|
||||
|
||||
public bool TryWriteMessage(HubMessage message, IOutput output)
|
||||
{
|
||||
// TODO: Need IOutput-compatible JSON serializer!
|
||||
using (var memoryStream = new MemoryStream())
|
||||
{
|
||||
WriteMessage(message, memoryStream);
|
||||
memoryStream.Flush();
|
||||
|
||||
return output.TryWrite(memoryStream.ToArray());
|
||||
}
|
||||
}
|
||||
|
||||
private HubMessage ParseMessage(Stream input, IInvocationBinder binder)
|
||||
{
|
||||
using (var reader = new JsonTextReader(new StreamReader(input)))
|
||||
{
|
||||
try
|
||||
{
|
||||
// PERF: Could probably use the JsonTextReader directly for better perf and fewer allocations
|
||||
var token = JToken.ReadFrom(reader);
|
||||
if (token == null)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
if (token.Type != JTokenType.Object)
|
||||
{
|
||||
throw new FormatException($"Unexpected JSON Token Type '{token.Type}'. Expected a JSON Object.");
|
||||
}
|
||||
|
||||
var json = (JObject)token;
|
||||
|
||||
// Determine the type of the message
|
||||
var type = GetRequiredProperty<int>(json, TypePropertyName, JTokenType.Integer);
|
||||
switch (type)
|
||||
{
|
||||
case InvocationMessageType:
|
||||
return BindInvocationMessage(json, binder);
|
||||
case ResultMessageType:
|
||||
return BindResultMessage(json, binder);
|
||||
case CompletionMessageType:
|
||||
return BindCompletionMessage(json, binder);
|
||||
default:
|
||||
throw new FormatException($"Unknown message type: {type}");
|
||||
}
|
||||
}
|
||||
catch (JsonReaderException jrex)
|
||||
{
|
||||
throw new FormatException("Error reading JSON.", jrex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void WriteMessage(HubMessage message, Stream stream)
|
||||
{
|
||||
using (var writer = new JsonTextWriter(new StreamWriter(stream)))
|
||||
{
|
||||
switch (message)
|
||||
{
|
||||
case InvocationMessage m:
|
||||
WriteInvocationMessage(m, writer);
|
||||
break;
|
||||
case StreamItemMessage m:
|
||||
WriteResultMessage(m, writer);
|
||||
break;
|
||||
case CompletionMessage m:
|
||||
WriteCompletionMessage(m, writer);
|
||||
break;
|
||||
default:
|
||||
throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void WriteCompletionMessage(CompletionMessage message, JsonTextWriter writer)
|
||||
{
|
||||
writer.WriteStartObject();
|
||||
WriteHubMessageCommon(message, writer, CompletionMessageType);
|
||||
if (!string.IsNullOrEmpty(message.Error))
|
||||
{
|
||||
writer.WritePropertyName(ErrorPropertyName);
|
||||
writer.WriteValue(message.Error);
|
||||
}
|
||||
else if (message.HasResult)
|
||||
{
|
||||
writer.WritePropertyName(ResultPropertyName);
|
||||
_payloadSerializer.Serialize(writer, message.Result);
|
||||
}
|
||||
writer.WriteEndObject();
|
||||
}
|
||||
|
||||
private void WriteResultMessage(StreamItemMessage message, JsonTextWriter writer)
|
||||
{
|
||||
writer.WriteStartObject();
|
||||
WriteHubMessageCommon(message, writer, ResultMessageType);
|
||||
writer.WritePropertyName(ResultPropertyName);
|
||||
_payloadSerializer.Serialize(writer, message.Item);
|
||||
writer.WriteEndObject();
|
||||
}
|
||||
|
||||
private void WriteInvocationMessage(InvocationMessage message, JsonTextWriter writer)
|
||||
{
|
||||
writer.WriteStartObject();
|
||||
WriteHubMessageCommon(message, writer, InvocationMessageType);
|
||||
writer.WritePropertyName(TargetPropertyName);
|
||||
writer.WriteValue(message.Target);
|
||||
|
||||
if (message.NonBlocking)
|
||||
{
|
||||
writer.WritePropertyName(NonBlockingPropertyName);
|
||||
writer.WriteValue(message.NonBlocking);
|
||||
}
|
||||
|
||||
writer.WritePropertyName(ArgumentsPropertyName);
|
||||
writer.WriteStartArray();
|
||||
foreach (var argument in message.Arguments)
|
||||
{
|
||||
_payloadSerializer.Serialize(writer, argument);
|
||||
}
|
||||
writer.WriteEndArray();
|
||||
|
||||
writer.WriteEndObject();
|
||||
}
|
||||
|
||||
private static void WriteHubMessageCommon(HubMessage message, JsonTextWriter writer, int type)
|
||||
{
|
||||
writer.WritePropertyName(InvocationIdPropertyName);
|
||||
writer.WriteValue(message.InvocationId);
|
||||
writer.WritePropertyName(TypePropertyName);
|
||||
writer.WriteValue(type);
|
||||
}
|
||||
|
||||
private InvocationMessage BindInvocationMessage(JObject json, IInvocationBinder binder)
|
||||
{
|
||||
var invocationId = GetRequiredProperty<string>(json, InvocationIdPropertyName, JTokenType.String);
|
||||
var target = GetRequiredProperty<string>(json, TargetPropertyName, JTokenType.String);
|
||||
var nonBlocking = GetOptionalProperty<bool>(json, NonBlockingPropertyName, JTokenType.Boolean);
|
||||
|
||||
var args = GetRequiredProperty<JArray>(json, ArgumentsPropertyName, JTokenType.Array);
|
||||
|
||||
var paramTypes = binder.GetParameterTypes(target);
|
||||
var arguments = new object[args.Count];
|
||||
if (paramTypes.Length != arguments.Length)
|
||||
{
|
||||
throw new FormatException($"Invocation provides {arguments.Length} argument(s) but target expects {paramTypes.Length}.");
|
||||
}
|
||||
|
||||
for (var i = 0; i < paramTypes.Length; i++)
|
||||
{
|
||||
var paramType = paramTypes[i];
|
||||
|
||||
// TODO(anurse): We can add some DI magic here to allow users to provide their own serialization
|
||||
// Related Bug: https://github.com/aspnet/SignalR/issues/261
|
||||
arguments[i] = args[i].ToObject(paramType, _payloadSerializer);
|
||||
}
|
||||
|
||||
return new InvocationMessage(invocationId, nonBlocking, target, arguments);
|
||||
}
|
||||
|
||||
private StreamItemMessage BindResultMessage(JObject json, IInvocationBinder binder)
|
||||
{
|
||||
var invocationId = GetRequiredProperty<string>(json, InvocationIdPropertyName, JTokenType.String);
|
||||
var result = GetRequiredProperty<JToken>(json, ResultPropertyName);
|
||||
|
||||
var returnType = binder.GetReturnType(invocationId);
|
||||
return new StreamItemMessage(invocationId, result?.ToObject(returnType, _payloadSerializer));
|
||||
}
|
||||
|
||||
private CompletionMessage BindCompletionMessage(JObject json, IInvocationBinder binder)
|
||||
{
|
||||
var invocationId = GetRequiredProperty<string>(json, InvocationIdPropertyName, JTokenType.String);
|
||||
var error = GetOptionalProperty<string>(json, ErrorPropertyName, JTokenType.String);
|
||||
var resultProp = json.Property(ResultPropertyName);
|
||||
|
||||
if (error != null && resultProp != null)
|
||||
{
|
||||
throw new FormatException("The 'error' and 'result' properties are mutually exclusive.");
|
||||
}
|
||||
|
||||
if (resultProp == null)
|
||||
{
|
||||
return new CompletionMessage(invocationId, error, result: null, hasResult: false);
|
||||
}
|
||||
else
|
||||
{
|
||||
var returnType = binder.GetReturnType(invocationId);
|
||||
var payload = resultProp.Value?.ToObject(returnType, _payloadSerializer);
|
||||
return new CompletionMessage(invocationId, error, result: payload, hasResult: true);
|
||||
}
|
||||
}
|
||||
|
||||
private T GetOptionalProperty<T>(JObject json, string property, JTokenType expectedType = JTokenType.None, T defaultValue = default(T))
|
||||
{
|
||||
var prop = json[property];
|
||||
|
||||
if (prop == null)
|
||||
{
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
return GetValue<T>(property, expectedType, prop);
|
||||
}
|
||||
|
||||
private T GetRequiredProperty<T>(JObject json, string property, JTokenType expectedType = JTokenType.None)
|
||||
{
|
||||
var prop = json[property];
|
||||
|
||||
if (prop == null)
|
||||
{
|
||||
throw new FormatException($"Missing required property '{property}'.");
|
||||
}
|
||||
|
||||
return GetValue<T>(property, expectedType, prop);
|
||||
}
|
||||
|
||||
private static T GetValue<T>(string property, JTokenType expectedType, JToken prop)
|
||||
{
|
||||
if (expectedType != JTokenType.None && prop.Type != expectedType)
|
||||
{
|
||||
throw new FormatException($"Expected '{property}' to be of type {expectedType}.");
|
||||
}
|
||||
return prop.Value<T>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
||||
{
|
||||
public class StreamItemMessage : HubMessage
|
||||
{
|
||||
public object Item { get; }
|
||||
|
||||
public StreamItemMessage(string invocationId, object item) : base(invocationId)
|
||||
{
|
||||
Item = item;
|
||||
}
|
||||
|
||||
public override string ToString()
|
||||
{
|
||||
return $"StreamItem {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Item)}: {Item ?? "<<null>>"} }}";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.IO;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
public static class InvocationAdapterExtensions
|
||||
{
|
||||
public static Task<InvocationMessage> ReadMessageAsync(this IInvocationAdapter self, Stream stream, IInvocationBinder binder) => self.ReadMessageAsync(stream, binder, CancellationToken.None);
|
||||
|
||||
public static Task WriteMessageAsync(this IInvocationAdapter self, InvocationMessage message, Stream stream) => self.WriteMessageAsync(message, stream, CancellationToken.None);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
public class InvocationDescriptor : InvocationMessage
|
||||
{
|
||||
public string Method { get; set; }
|
||||
|
||||
public object[] Arguments { get; set; }
|
||||
|
||||
public override string ToString()
|
||||
{
|
||||
return $"{Id}: {Method}({(Arguments ?? new object[0]).Length})";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
public class InvocationResultDescriptor : InvocationMessage
|
||||
{
|
||||
public object Result { get; set; }
|
||||
|
||||
public string Error { get; set; }
|
||||
}
|
||||
}
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.IO;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.Internal;
|
||||
using Newtonsoft.Json;
|
||||
using Newtonsoft.Json.Linq;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
public class JsonNetInvocationAdapter : IInvocationAdapter
|
||||
{
|
||||
private JsonSerializer _serializer = new JsonSerializer();
|
||||
|
||||
public JsonNetInvocationAdapter()
|
||||
{
|
||||
}
|
||||
|
||||
public Task<InvocationMessage> ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken)
|
||||
{
|
||||
var reader = new JsonTextReader(new StreamReader(stream));
|
||||
// REVIEW: Task.Run()
|
||||
return Task.Run<InvocationMessage>(() =>
|
||||
{
|
||||
cancellationToken.ThrowIfCancellationRequested();
|
||||
var json = _serializer.Deserialize<JObject>(reader);
|
||||
if (json == null)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
// Determine the type of the message
|
||||
if (json["Result"] != null)
|
||||
{
|
||||
// It's a result
|
||||
return BindInvocationResultDescriptor(json, binder, cancellationToken);
|
||||
}
|
||||
else
|
||||
{
|
||||
return BindInvocationDescriptor(json, binder, cancellationToken);
|
||||
}
|
||||
}, cancellationToken);
|
||||
}
|
||||
|
||||
public Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken)
|
||||
{
|
||||
var writer = new JsonTextWriter(new StreamWriter(stream));
|
||||
_serializer.Serialize(writer, message);
|
||||
writer.Flush();
|
||||
return TaskCache.CompletedTask;
|
||||
}
|
||||
|
||||
private InvocationDescriptor BindInvocationDescriptor(JObject json, IInvocationBinder binder, CancellationToken cancellationToken)
|
||||
{
|
||||
var invocation = new InvocationDescriptor
|
||||
{
|
||||
Id = json.Value<string>("Id"),
|
||||
Method = json.Value<string>("Method"),
|
||||
};
|
||||
|
||||
var paramTypes = binder.GetParameterTypes(invocation.Method);
|
||||
invocation.Arguments = new object[paramTypes.Length];
|
||||
|
||||
var args = json.Value<JArray>("Arguments");
|
||||
for (var i = 0; i < paramTypes.Length; i++)
|
||||
{
|
||||
var paramType = paramTypes[i];
|
||||
invocation.Arguments[i] = args[i].ToObject(paramType, _serializer);
|
||||
}
|
||||
|
||||
return invocation;
|
||||
}
|
||||
|
||||
private InvocationResultDescriptor BindInvocationResultDescriptor(JObject json, IInvocationBinder binder, CancellationToken cancellationToken)
|
||||
{
|
||||
var id = json.Value<string>("Id");
|
||||
var returnType = binder.GetReturnType(id);
|
||||
var result = new InvocationResultDescriptor()
|
||||
{
|
||||
Id = id,
|
||||
Result = returnType == null ? null : json["Result"].ToObject(returnType, _serializer),
|
||||
Error = json.Value<string>("Error")
|
||||
};
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -9,11 +9,20 @@
|
|||
<GenerateDocumentationFile>true</GenerateDocumentationFile>
|
||||
<PackageTags>aspnetcore;signalr</PackageTags>
|
||||
<EnableApiCheck>false</EnableApiCheck>
|
||||
<RootNamespace>Microsoft.AspNetCore.SignalR</RootNamespace>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
|
||||
<Compile Include="../Common/IOutputExtensions.cs" Link="IOutputExtensions.cs" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Newtonsoft.Json" Version="$(JsonNetVersion)" />
|
||||
<PackageReference Include="System.IO.Pipelines.Text.Primitives" Version="$(CoreFxLabsVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common\Microsoft.AspNetCore.Sockets.Common.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@
|
|||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.SignalR\Microsoft.AspNetCore.SignalR.csproj" />
|
||||
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
|
||||
<PackageReference Include="StackExchange.Redis.StrongName" Version="$(RedisVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
|
|
|
|||
|
|
@ -5,53 +5,80 @@ using System;
|
|||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.Extensions.Internal;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Newtonsoft.Json;
|
||||
using StackExchange.Redis;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis
|
||||
{
|
||||
public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable
|
||||
{
|
||||
private const string RedisSubscriptionsMetadataName = "redis_subscriptions";
|
||||
|
||||
private readonly ConnectionList _connections = new ConnectionList();
|
||||
// TODO: Investigate "memory leak" entries never get removed
|
||||
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>();
|
||||
private readonly InvocationAdapterRegistry _registry;
|
||||
private readonly ConnectionMultiplexer _redisServerConnection;
|
||||
private readonly ISubscriber _bus;
|
||||
private readonly ILoggerFactory _loggerFactory;
|
||||
private readonly ILogger _logger;
|
||||
private readonly RedisOptions _options;
|
||||
|
||||
public RedisHubLifetimeManager(InvocationAdapterRegistry registry,
|
||||
ILoggerFactory loggerFactory,
|
||||
// This serializer is ONLY use to transmit the data through redis, it has no connection to the serializer used on each connection.
|
||||
private readonly JsonSerializer _serializer = new JsonSerializer
|
||||
{
|
||||
// We need to serialize objects "full-fidelity", even if it is noisy, so we preserve the original types
|
||||
TypeNameAssemblyFormatHandling = TypeNameAssemblyFormatHandling.Simple,
|
||||
TypeNameHandling = TypeNameHandling.All,
|
||||
Formatting = Formatting.None
|
||||
};
|
||||
|
||||
private long _nextInvocationId = 0;
|
||||
|
||||
public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
|
||||
IOptions<RedisOptions> options)
|
||||
{
|
||||
_loggerFactory = loggerFactory;
|
||||
_registry = registry;
|
||||
_logger = logger;
|
||||
_options = options.Value;
|
||||
|
||||
var writer = new LoggerTextWriter(loggerFactory.CreateLogger<RedisHubLifetimeManager<THub>>());
|
||||
var writer = new LoggerTextWriter(logger);
|
||||
_logger.LogInformation("Connecting to redis endpoints: {endpoints}", string.Join(", ", options.Value.Options.EndPoints.Select(e => EndPointCollection.ToString(e))));
|
||||
_redisServerConnection = _options.Connect(writer);
|
||||
if (_redisServerConnection.IsConnected)
|
||||
{
|
||||
_logger.LogInformation("Connected to redis");
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: We could support reconnecting, like old SignalR does.
|
||||
throw new InvalidOperationException("Connection to redis failed.");
|
||||
}
|
||||
_bus = _redisServerConnection.GetSubscriber();
|
||||
|
||||
var previousBroadcastTask = TaskCache.CompletedTask;
|
||||
var previousBroadcastTask = Task.CompletedTask;
|
||||
|
||||
_bus.Subscribe(typeof(THub).FullName, async (c, data) =>
|
||||
var channelName = typeof(THub).FullName;
|
||||
_logger.LogInformation("Subscribing to channel: {channel}", channelName);
|
||||
_bus.Subscribe(channelName, async (c, data) =>
|
||||
{
|
||||
await previousBroadcastTask;
|
||||
|
||||
_logger.LogTrace("Received message from redis channel {channel}", channelName);
|
||||
|
||||
var message = DeserializeMessage(data);
|
||||
|
||||
// TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
|
||||
var tasks = new List<Task>(_connections.Count);
|
||||
|
||||
foreach (var connection in _connections)
|
||||
{
|
||||
tasks.Add(WriteAsync(connection, data));
|
||||
tasks.Add(WriteAsync(connection, message));
|
||||
}
|
||||
|
||||
previousBroadcastTask = Task.WhenAll(tasks);
|
||||
|
|
@ -60,80 +87,68 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
public override Task InvokeAllAsync(string methodName, object[] args)
|
||||
{
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||
|
||||
return PublishAsync(typeof(THub).FullName, message);
|
||||
}
|
||||
|
||||
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
|
||||
{
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||
|
||||
return PublishAsync(typeof(THub).FullName + "." + connectionId, message);
|
||||
}
|
||||
|
||||
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
|
||||
{
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||
|
||||
return PublishAsync(typeof(THub).FullName + ".group." + groupName, message);
|
||||
}
|
||||
|
||||
public override Task InvokeUserAsync(string userId, string methodName, object[] args)
|
||||
{
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||
|
||||
return PublishAsync(typeof(THub).FullName + ".user." + userId, message);
|
||||
}
|
||||
|
||||
private async Task PublishAsync(string channel, InvocationDescriptor message)
|
||||
private async Task PublishAsync(string channel, HubMessage hubMessage)
|
||||
{
|
||||
// TODO: What format??
|
||||
var invocationAdapter = _registry.GetInvocationAdapter("json");
|
||||
|
||||
// BAD
|
||||
using (var ms = new MemoryStream())
|
||||
byte[] payload;
|
||||
using (var stream = new MemoryStream())
|
||||
using (var writer = new JsonTextWriter(new StreamWriter(stream)))
|
||||
{
|
||||
await invocationAdapter.WriteMessageAsync(message, ms);
|
||||
|
||||
await _bus.PublishAsync(channel, ms.ToArray());
|
||||
_serializer.Serialize(writer, hubMessage);
|
||||
await writer.FlushAsync();
|
||||
payload = stream.ToArray();
|
||||
}
|
||||
|
||||
_logger.LogTrace("Publishing message to redis channel {channel}", channel);
|
||||
await _bus.PublishAsync(channel, payload);
|
||||
}
|
||||
|
||||
public override Task OnConnectedAsync(Connection connection)
|
||||
{
|
||||
var redisSubscriptions = connection.Metadata.GetOrAdd("redis_subscriptions", _ => new HashSet<string>());
|
||||
var connectionTask = TaskCache.CompletedTask;
|
||||
var userTask = TaskCache.CompletedTask;
|
||||
var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet<string>());
|
||||
var connectionTask = Task.CompletedTask;
|
||||
var userTask = Task.CompletedTask;
|
||||
|
||||
_connections.Add(connection);
|
||||
|
||||
var connectionChannel = typeof(THub).FullName + "." + connection.ConnectionId;
|
||||
redisSubscriptions.Add(connectionChannel);
|
||||
|
||||
var previousConnectionTask = TaskCache.CompletedTask;
|
||||
var previousConnectionTask = Task.CompletedTask;
|
||||
|
||||
_logger.LogInformation("Subscribing to connection channel: {channel}", connectionChannel);
|
||||
connectionTask = _bus.SubscribeAsync(connectionChannel, async (c, data) =>
|
||||
{
|
||||
await previousConnectionTask;
|
||||
|
||||
previousConnectionTask = WriteAsync(connection, data);
|
||||
var message = DeserializeMessage(data);
|
||||
|
||||
previousConnectionTask = WriteAsync(connection, message);
|
||||
});
|
||||
|
||||
|
||||
|
|
@ -142,14 +157,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
var userChannel = typeof(THub).FullName + ".user." + connection.User.Identity.Name;
|
||||
redisSubscriptions.Add(userChannel);
|
||||
|
||||
var previousUserTask = TaskCache.CompletedTask;
|
||||
var previousUserTask = Task.CompletedTask;
|
||||
|
||||
// TODO: Look at optimizing (looping over connections checking for Name)
|
||||
userTask = _bus.SubscribeAsync(userChannel, async (c, data) =>
|
||||
{
|
||||
await previousUserTask;
|
||||
|
||||
previousUserTask = WriteAsync(connection, data);
|
||||
var message = DeserializeMessage(data);
|
||||
|
||||
previousUserTask = WriteAsync(connection, message);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -162,16 +179,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
var tasks = new List<Task>();
|
||||
|
||||
var redisSubscriptions = connection.Metadata.Get<HashSet<string>>("redis_subscriptions");
|
||||
var redisSubscriptions = connection.Metadata.Get<HashSet<string>>(RedisSubscriptionsMetadataName);
|
||||
if (redisSubscriptions != null)
|
||||
{
|
||||
foreach (var subscription in redisSubscriptions)
|
||||
{
|
||||
_logger.LogInformation("Unsubscribing from channel: {channel}", subscription);
|
||||
tasks.Add(_bus.UnsubscribeAsync(subscription));
|
||||
}
|
||||
}
|
||||
|
||||
var groupNames = connection.Metadata.Get<HashSet<string>>("group");
|
||||
var groupNames = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
|
||||
|
||||
if (groupNames != null)
|
||||
{
|
||||
|
|
@ -190,7 +208,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
{
|
||||
var groupChannel = typeof(THub).FullName + ".group." + groupName;
|
||||
|
||||
var groupNames = connection.Metadata.GetOrAdd("group", _ => new HashSet<string>());
|
||||
var groupNames = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet<string>());
|
||||
|
||||
lock (groupNames)
|
||||
{
|
||||
|
|
@ -210,8 +228,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
return;
|
||||
}
|
||||
|
||||
var previousTask = TaskCache.CompletedTask;
|
||||
var previousTask = Task.CompletedTask;
|
||||
|
||||
_logger.LogInformation("Subscribing to group channel: {channel}", groupChannel);
|
||||
await _bus.SubscribeAsync(groupChannel, async (c, data) =>
|
||||
{
|
||||
// Since this callback is async, we await the previous task then
|
||||
|
|
@ -219,10 +238,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
// want to do concurrent writes to the outgoing connections
|
||||
await previousTask;
|
||||
|
||||
var message = DeserializeMessage(data);
|
||||
|
||||
var tasks = new List<Task>(group.Connections.Count);
|
||||
foreach (var groupConnection in group.Connections)
|
||||
{
|
||||
tasks.Add(WriteAsync(groupConnection, data));
|
||||
tasks.Add(WriteAsync(groupConnection, message));
|
||||
}
|
||||
|
||||
previousTask = Task.WhenAll(tasks);
|
||||
|
|
@ -244,7 +265,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
return;
|
||||
}
|
||||
|
||||
var groupNames = connection.Metadata.Get<HashSet<string>>("group");
|
||||
var groupNames = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
|
||||
if (groupNames != null)
|
||||
{
|
||||
lock (groupNames)
|
||||
|
|
@ -260,6 +281,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
if (group.Connections.Count == 0)
|
||||
{
|
||||
_logger.LogInformation("Unsubscribing from group channel: {channel}", groupChannel);
|
||||
await _bus.UnsubscribeAsync(groupChannel);
|
||||
}
|
||||
}
|
||||
|
|
@ -275,9 +297,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
_redisServerConnection.Dispose();
|
||||
}
|
||||
|
||||
private async Task WriteAsync(Connection connection, byte[] data)
|
||||
private async Task WriteAsync(Connection connection, HubMessage hubMessage)
|
||||
{
|
||||
var message = new Message(data, MessageType.Text, endOfMessage: true);
|
||||
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
|
||||
var data = await protocol.WriteToArrayAsync(hubMessage);
|
||||
var message = new Message(data, protocol.MessageType, endOfMessage: true);
|
||||
|
||||
while (await connection.Transport.Output.WaitToWriteAsync())
|
||||
{
|
||||
|
|
@ -288,6 +312,23 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
}
|
||||
}
|
||||
|
||||
private string GetInvocationId()
|
||||
{
|
||||
var invocationId = Interlocked.Increment(ref _nextInvocationId);
|
||||
return invocationId.ToString();
|
||||
}
|
||||
|
||||
private HubMessage DeserializeMessage(RedisValue data)
|
||||
{
|
||||
HubMessage message;
|
||||
using (var reader = new JsonTextReader(new StreamReader(new MemoryStream((byte[])data))))
|
||||
{
|
||||
message = (HubMessage)_serializer.Deserialize(reader);
|
||||
}
|
||||
|
||||
return message;
|
||||
}
|
||||
|
||||
private class LoggerTextWriter : TextWriter
|
||||
{
|
||||
private readonly ILogger _logger;
|
||||
|
|
|
|||
|
|
@ -3,43 +3,37 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.Extensions.Internal;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
|
||||
{
|
||||
private long _nextInvocationId = 0;
|
||||
private readonly ConnectionList _connections = new ConnectionList();
|
||||
private readonly InvocationAdapterRegistry _registry;
|
||||
|
||||
public DefaultHubLifetimeManager(InvocationAdapterRegistry registry)
|
||||
{
|
||||
_registry = registry;
|
||||
}
|
||||
|
||||
public override Task AddGroupAsync(Connection connection, string groupName)
|
||||
{
|
||||
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
|
||||
var groups = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet<string>());
|
||||
|
||||
lock (groups)
|
||||
{
|
||||
groups.Add(groupName);
|
||||
}
|
||||
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public override Task RemoveGroupAsync(Connection connection, string groupName)
|
||||
{
|
||||
var groups = connection.Metadata.Get<HashSet<string>>("groups");
|
||||
var groups = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
|
||||
|
||||
if (groups == null)
|
||||
{
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
lock (groups)
|
||||
|
|
@ -47,7 +41,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
groups.Remove(groupName);
|
||||
}
|
||||
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public override Task InvokeAllAsync(string methodName, object[] args)
|
||||
|
|
@ -58,11 +52,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
private Task InvokeAllWhere(string methodName, object[] args, Func<Connection, bool> include)
|
||||
{
|
||||
var tasks = new List<Task>(_connections.Count);
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||
|
||||
// TODO: serialize once per format by providing a different stream?
|
||||
foreach (var connection in _connections)
|
||||
|
|
@ -72,9 +62,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
continue;
|
||||
}
|
||||
|
||||
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
|
||||
|
||||
tasks.Add(WriteAsync(connection, invocationAdapter, message));
|
||||
tasks.Add(WriteAsync(connection, message));
|
||||
}
|
||||
|
||||
return Task.WhenAll(tasks);
|
||||
|
|
@ -84,22 +72,16 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
{
|
||||
var connection = _connections[connectionId];
|
||||
|
||||
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
|
||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
|
||||
return WriteAsync(connection, invocationAdapter, message);
|
||||
return WriteAsync(connection, message);
|
||||
}
|
||||
|
||||
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
|
||||
{
|
||||
return InvokeAllWhere(methodName, args, connection =>
|
||||
{
|
||||
var groups = connection.Metadata.Get<HashSet<string>>("groups");
|
||||
var groups = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
|
||||
return groups?.Contains(groupName) == true;
|
||||
});
|
||||
}
|
||||
|
|
@ -115,21 +97,20 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
public override Task OnConnectedAsync(Connection connection)
|
||||
{
|
||||
_connections.Add(connection);
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public override Task OnDisconnectedAsync(Connection connection)
|
||||
{
|
||||
_connections.Remove(connection);
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private static async Task WriteAsync(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor invocation)
|
||||
private async Task WriteAsync(Connection connection, HubMessage hubMessage)
|
||||
{
|
||||
var stream = new MemoryStream();
|
||||
await invocationAdapter.WriteMessageAsync(invocation, stream);
|
||||
|
||||
var message = new Message(stream.ToArray(), MessageType.Text, endOfMessage: true);
|
||||
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
|
||||
var payload = await protocol.WriteToArrayAsync(hubMessage);
|
||||
var message = new Message(payload, protocol.MessageType, endOfMessage: true);
|
||||
|
||||
while (await connection.Transport.Output.WaitToWriteAsync())
|
||||
{
|
||||
|
|
@ -139,5 +120,11 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
private string GetInvocationId()
|
||||
{
|
||||
var invocationId = Interlocked.Increment(ref _nextInvocationId);
|
||||
return invocationId.ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
using System;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.Internal;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
|
|
@ -62,12 +61,12 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
public virtual Task OnConnectedAsync()
|
||||
{
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public virtual Task OnDisconnectedAsync(Exception exception)
|
||||
{
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
protected virtual void Dispose(bool disposing)
|
||||
|
|
|
|||
|
|
@ -1,15 +1,11 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
public abstract class InvocationMessage
|
||||
public static class HubConnectionMetadataNames
|
||||
{
|
||||
public string Id { get; set; }
|
||||
public static readonly string HubProtocol = nameof(HubProtocol);
|
||||
public static readonly string Groups = nameof(Groups);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,13 +1,14 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Internal;
|
||||
|
|
@ -19,12 +20,12 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
public class HubEndPoint<THub> : HubEndPoint<THub, IClientProxy> where THub : Hub<IClientProxy>
|
||||
{
|
||||
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
|
||||
IHubProtocolResolver protocolResolver,
|
||||
IHubContext<THub> hubContext,
|
||||
InvocationAdapterRegistry registry,
|
||||
IOptions<EndPointOptions<HubEndPoint<THub, IClientProxy>>> endPointOptions,
|
||||
ILogger<HubEndPoint<THub>> logger,
|
||||
IServiceScopeFactory serviceScopeFactory)
|
||||
: base(lifetimeManager, hubContext, registry, endPointOptions, logger, serviceScopeFactory)
|
||||
: base(lifetimeManager, protocolResolver, hubContext, endPointOptions, logger, serviceScopeFactory)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
|
@ -36,19 +37,19 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
private readonly HubLifetimeManager<THub> _lifetimeManager;
|
||||
private readonly IHubContext<THub, TClient> _hubContext;
|
||||
private readonly ILogger<HubEndPoint<THub, TClient>> _logger;
|
||||
private readonly InvocationAdapterRegistry _registry;
|
||||
private readonly IServiceScopeFactory _serviceScopeFactory;
|
||||
private readonly IHubProtocolResolver _protocolResolver;
|
||||
|
||||
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
|
||||
IHubProtocolResolver protocolResolver,
|
||||
IHubContext<THub, TClient> hubContext,
|
||||
InvocationAdapterRegistry registry,
|
||||
IOptions<EndPointOptions<HubEndPoint<THub, TClient>>> endPointOptions,
|
||||
ILogger<HubEndPoint<THub, TClient>> logger,
|
||||
IServiceScopeFactory serviceScopeFactory)
|
||||
{
|
||||
_protocolResolver = protocolResolver;
|
||||
_lifetimeManager = lifetimeManager;
|
||||
_hubContext = hubContext;
|
||||
_registry = registry;
|
||||
_logger = logger;
|
||||
_serviceScopeFactory = serviceScopeFactory;
|
||||
|
||||
|
|
@ -59,6 +60,11 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
{
|
||||
try
|
||||
{
|
||||
// Resolve the Hub Protocol for the connection and store it in metadata
|
||||
// Other components, outside the Hub, may need to know what protocol is in use
|
||||
// for a particular connection, so we store it here.
|
||||
connection.Metadata[HubConnectionMetadataNames.HubProtocol] = _protocolResolver.GetProtocol(connection);
|
||||
|
||||
await _lifetimeManager.OnConnectedAsync(connection);
|
||||
await RunHubAsync(connection);
|
||||
}
|
||||
|
|
@ -140,14 +146,13 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
private async Task DispatchMessagesAsync(Connection connection)
|
||||
{
|
||||
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
|
||||
|
||||
// We use these for error handling. Since we dispatch multiple hub invocations
|
||||
// in parallel, we need a way to communicate failure back to the main processing loop. The
|
||||
// cancellation token is used to stop reading from the channel, the tcs
|
||||
// is used to get the exception so we can bubble it up the stack
|
||||
var cts = new CancellationTokenSource();
|
||||
var tcs = new TaskCompletionSource<object>();
|
||||
var completion = new TaskCompletionSource<object>();
|
||||
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
|
||||
|
||||
try
|
||||
{
|
||||
|
|
@ -155,100 +160,94 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
{
|
||||
while (connection.Transport.Input.TryRead(out var incomingMessage))
|
||||
{
|
||||
InvocationDescriptor invocationDescriptor;
|
||||
var inputStream = new MemoryStream(incomingMessage.Payload);
|
||||
var hubMessage = protocol.ParseMessage(incomingMessage.Payload, this);
|
||||
|
||||
// TODO: Handle receiving InvocationResultDescriptor
|
||||
invocationDescriptor = await invocationAdapter.ReadMessageAsync(inputStream, this) as InvocationDescriptor;
|
||||
|
||||
// Is there a better way of detecting that a connection was closed?
|
||||
if (invocationDescriptor == null)
|
||||
switch (hubMessage)
|
||||
{
|
||||
break;
|
||||
}
|
||||
case InvocationMessage invocationMessage:
|
||||
if (_logger.IsEnabled(LogLevel.Debug))
|
||||
{
|
||||
_logger.LogDebug("Received hub invocation: {invocation}", invocationMessage);
|
||||
}
|
||||
|
||||
if (_logger.IsEnabled(LogLevel.Debug))
|
||||
{
|
||||
_logger.LogDebug("Received hub invocation: {invocation}", invocationDescriptor);
|
||||
}
|
||||
// Don't wait on the result of execution, continue processing other
|
||||
// incoming messages on this connection.
|
||||
var ignore = ProcessInvocation(connection, protocol, invocationMessage, cts, completion);
|
||||
break;
|
||||
|
||||
// Don't wait on the result of execution, continue processing other
|
||||
// incoming messages on this connection.
|
||||
var ignore = ProcessInvocation(connection, invocationAdapter, invocationDescriptor, cts, tcs);
|
||||
// Other kind of message we weren't expecting
|
||||
default:
|
||||
_logger.LogError("Received unsupported message of type '{messageType}'", hubMessage.GetType().FullName);
|
||||
throw new NotSupportedException($"Received unsupported message: {hubMessage}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Await the task so the exception bubbles up to the caller
|
||||
await tcs.Task;
|
||||
await completion.Task;
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ProcessInvocation(Connection connection,
|
||||
IInvocationAdapter invocationAdapter,
|
||||
InvocationDescriptor invocationDescriptor,
|
||||
CancellationTokenSource cts,
|
||||
TaskCompletionSource<object> tcs)
|
||||
IHubProtocol protocol,
|
||||
InvocationMessage invocationMessage,
|
||||
CancellationTokenSource dispatcherCancellation,
|
||||
TaskCompletionSource<object> dispatcherCompletion)
|
||||
{
|
||||
try
|
||||
{
|
||||
// If an unexpected exception occurs then we want to kill the entire connection
|
||||
// by ending the processing loop
|
||||
await Execute(connection, invocationAdapter, invocationDescriptor);
|
||||
await Execute(connection, protocol, invocationMessage);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
// Set the exception on the task completion source
|
||||
tcs.TrySetException(ex);
|
||||
dispatcherCompletion.TrySetException(ex);
|
||||
|
||||
// Cancel reading operation
|
||||
cts.Cancel();
|
||||
dispatcherCancellation.Cancel();
|
||||
}
|
||||
}
|
||||
|
||||
private async Task Execute(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor invocationDescriptor)
|
||||
private async Task Execute(Connection connection, IHubProtocol protocol, InvocationMessage invocationMessage)
|
||||
{
|
||||
InvocationResultDescriptor result;
|
||||
HubMethodDescriptor descriptor;
|
||||
if (_methods.TryGetValue(invocationDescriptor.Method, out descriptor))
|
||||
if (!_methods.TryGetValue(invocationMessage.Target, out descriptor))
|
||||
{
|
||||
result = await Invoke(descriptor, connection, invocationDescriptor);
|
||||
// Send an error to the client. Then let the normal completion process occur
|
||||
_logger.LogError("Unknown hub method '{method}'", invocationMessage.Target);
|
||||
await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, $"Unknown hub method '{invocationMessage.Target}'"));
|
||||
}
|
||||
else
|
||||
{
|
||||
// If there's no method then return a failed response for this request
|
||||
result = new InvocationResultDescriptor
|
||||
{
|
||||
Id = invocationDescriptor.Id,
|
||||
Error = $"Unknown hub method '{invocationDescriptor.Method}'"
|
||||
};
|
||||
|
||||
_logger.LogError("Unknown hub method '{method}'", invocationDescriptor.Method);
|
||||
}
|
||||
|
||||
// TODO: Pool memory
|
||||
var outStream = new MemoryStream();
|
||||
await invocationAdapter.WriteMessageAsync(result, outStream);
|
||||
|
||||
var outMessage = new Message(outStream.ToArray(), MessageType.Text, endOfMessage: true);
|
||||
|
||||
while (await connection.Transport.Output.WaitToWriteAsync())
|
||||
{
|
||||
if (connection.Transport.Output.TryWrite(outMessage))
|
||||
{
|
||||
break;
|
||||
}
|
||||
var result = await Invoke(descriptor, connection, invocationMessage);
|
||||
await SendMessageAsync(connection, protocol, result);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<InvocationResultDescriptor> Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationDescriptor invocationDescriptor)
|
||||
private async Task SendMessageAsync(Connection connection, IHubProtocol protocol, HubMessage hubMessage)
|
||||
{
|
||||
var invocationResult = new InvocationResultDescriptor
|
||||
{
|
||||
Id = invocationDescriptor.Id
|
||||
};
|
||||
var payload = await protocol.WriteToArrayAsync(hubMessage);
|
||||
var message = new Message(payload, protocol.MessageType, endOfMessage: true);
|
||||
|
||||
while (await connection.Transport.Output.WaitToWriteAsync())
|
||||
{
|
||||
if (connection.Transport.Output.TryWrite(message))
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Output is closed. Cancel this invocation completely
|
||||
_logger.LogWarning("Outbound channel was closed while trying to write hub message");
|
||||
throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
|
||||
}
|
||||
|
||||
private async Task<CompletionMessage> Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationMessage invocationMessage)
|
||||
{
|
||||
var methodExecutor = descriptor.MethodExecutor;
|
||||
|
||||
using (var scope = _serviceScopeFactory.CreateScope())
|
||||
|
|
@ -265,37 +264,35 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
{
|
||||
if (methodExecutor.MethodReturnType == typeof(Task))
|
||||
{
|
||||
await (Task)methodExecutor.Execute(hub, invocationDescriptor.Arguments);
|
||||
await (Task)methodExecutor.Execute(hub, invocationMessage.Arguments);
|
||||
}
|
||||
else
|
||||
{
|
||||
result = await methodExecutor.ExecuteAsync(hub, invocationDescriptor.Arguments);
|
||||
result = await methodExecutor.ExecuteAsync(hub, invocationMessage.Arguments);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
result = methodExecutor.Execute(hub, invocationDescriptor.Arguments);
|
||||
result = methodExecutor.Execute(hub, invocationMessage.Arguments);
|
||||
}
|
||||
|
||||
invocationResult.Result = result;
|
||||
return CompletionMessage.WithResult(invocationMessage.InvocationId, result);
|
||||
}
|
||||
catch (TargetInvocationException ex)
|
||||
{
|
||||
_logger.LogError(0, ex, "Failed to invoke hub method");
|
||||
invocationResult.Error = ex.InnerException.Message;
|
||||
return CompletionMessage.WithError(invocationMessage.InvocationId, ex.InnerException.Message);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(0, ex, "Failed to invoke hub method");
|
||||
invocationResult.Error = ex.Message;
|
||||
return CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message);
|
||||
}
|
||||
finally
|
||||
{
|
||||
hubActivator.Release(hub);
|
||||
}
|
||||
}
|
||||
|
||||
return invocationResult;
|
||||
}
|
||||
|
||||
private void InitializeHub(THub hub, Connection connection)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Newtonsoft.Json;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||
{
|
||||
public class DefaultHubProtocolResolver : IHubProtocolResolver
|
||||
{
|
||||
public IHubProtocol GetProtocol(Connection connection)
|
||||
{
|
||||
// TODO: Allow customization of this serializer!
|
||||
return new JsonHubProtocol(new JsonSerializer());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||
{
|
||||
public interface IHubProtocolResolver
|
||||
{
|
||||
IHubProtocol GetProtocol(Connection connection);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Options;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
public class InvocationAdapterRegistry
|
||||
{
|
||||
private readonly IServiceProvider _serviceProvider;
|
||||
private readonly SignalROptions _options;
|
||||
|
||||
public InvocationAdapterRegistry(IOptions<SignalROptions> options, IServiceProvider serviceProvider)
|
||||
{
|
||||
_options = options.Value;
|
||||
_serviceProvider = serviceProvider;
|
||||
}
|
||||
|
||||
public IInvocationAdapter GetInvocationAdapter(string format)
|
||||
{
|
||||
Type type;
|
||||
if (_options._invocationMappings.TryGetValue(format, out type))
|
||||
{
|
||||
return _serviceProvider.GetRequiredService(type) as IInvocationAdapter;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -14,7 +14,6 @@
|
|||
<ItemGroup>
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets\Microsoft.AspNetCore.Sockets.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />
|
||||
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
|
||||
<PackageReference Include="Microsoft.Extensions.ClosedGenericMatcher.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
|
||||
<PackageReference Include="Microsoft.Extensions.ObjectMethodExecutor.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
|
||||
<PackageReference Include="Newtonsoft.Json" Version="$(JsonNetVersion)" />
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using Microsoft.AspNetCore.SignalR;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
|
||||
namespace Microsoft.Extensions.DependencyInjection
|
||||
{
|
||||
|
|
@ -13,26 +12,13 @@ namespace Microsoft.Extensions.DependencyInjection
|
|||
{
|
||||
services.AddSockets();
|
||||
services.AddSingleton(typeof(HubLifetimeManager<>), typeof(DefaultHubLifetimeManager<>));
|
||||
services.AddSingleton(typeof(IHubProtocolResolver), typeof(DefaultHubProtocolResolver));
|
||||
services.AddSingleton(typeof(IHubContext<>), typeof(HubContext<>));
|
||||
services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>));
|
||||
services.AddSingleton<IConfigureOptions<SignalROptions>, SignalROptionsSetup>();
|
||||
services.AddSingleton<JsonNetInvocationAdapter>();
|
||||
services.AddSingleton<InvocationAdapterRegistry>();
|
||||
services.AddScoped(typeof(IHubActivator<,>), typeof(DefaultHubActivator<,>));
|
||||
services.AddRouting();
|
||||
|
||||
return new SignalRBuilder(services);
|
||||
}
|
||||
|
||||
public static ISignalRBuilder AddSignalR(this IServiceCollection services, Action<SignalROptions> setupAction)
|
||||
{
|
||||
return services.AddSignalR().AddSignalROptions(setupAction);
|
||||
}
|
||||
|
||||
public static ISignalRBuilder AddSignalROptions(this ISignalRBuilder builder, Action<SignalROptions> setupAction)
|
||||
{
|
||||
builder.Services.Configure(setupAction);
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,20 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
public class SignalROptions
|
||||
{
|
||||
internal readonly Dictionary<string, Type> _invocationMappings = new Dictionary<string, Type>();
|
||||
|
||||
public void RegisterInvocationAdapter<TInvocationAdapter>(string format) where TInvocationAdapter : IInvocationAdapter
|
||||
{
|
||||
_invocationMappings[format] = typeof(TInvocationAdapter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using Microsoft.AspNetCore.SignalR;
|
||||
using Microsoft.Extensions.Options;
|
||||
|
||||
namespace Microsoft.Extensions.DependencyInjection
|
||||
{
|
||||
public class SignalROptionsSetup : IConfigureOptions<SignalROptions>
|
||||
{
|
||||
public void Configure(SignalROptions options)
|
||||
{
|
||||
options.RegisterInvocationAdapter<JsonNetInvocationAdapter>("json");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -10,7 +10,6 @@ using System.Net.Http.Headers;
|
|||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
|
||||
using Microsoft.Extensions.Internal;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
|
||||
|
|
@ -58,7 +57,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
return t;
|
||||
}).Unwrap();
|
||||
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public async Task StopAsync()
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@
|
|||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
|
||||
<PackageReference Include="System.Text.Formatting" Version="$(CoreFxLabsVersion)" />
|
||||
<PackageReference Include="System.IO.Pipelines" Version="$(CoreFxLabsVersion)" />
|
||||
<PackageReference Include="System.IO.Pipelines.Text.Primitives" Version="$(CoreFxLabsVersion)" />
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
|
|
@ -8,7 +8,6 @@ using System.Net.Http.Headers;
|
|||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
|
||||
using Microsoft.Extensions.Internal;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
|
||||
|
|
@ -61,7 +60,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
return t;
|
||||
}).Unwrap();
|
||||
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private async Task OpenConnection(IChannelConnection<SendMessage, Message> application, Uri url, CancellationToken cancellationToken)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@
|
|||
<EnableApiCheck>false</EnableApiCheck>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Compile Include="../Common/IOutputExtensions.cs" Link="IOutputExtensions.cs" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="System.Binary.Base64" Version="$(CoreFxLabsVersion)" />
|
||||
<PackageReference Include="System.IO.Pipelines" Version="$(CoreFxLabsVersion)" />
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public static class ConnectionMetadataNames
|
||||
{
|
||||
public static readonly string Format = nameof(Format);
|
||||
public static readonly string Transport = nameof(Transport);
|
||||
public static readonly string HttpContext = nameof(HttpContext);
|
||||
}
|
||||
}
|
||||
|
|
@ -6,7 +6,7 @@ using Microsoft.AspNetCore.Sockets;
|
|||
|
||||
namespace Microsoft.Extensions.DependencyInjection
|
||||
{
|
||||
public static class EndpointDependencyInjectionExtensions
|
||||
public static class EndPointDependencyInjectionExtensions
|
||||
{
|
||||
public static IServiceCollection AddEndPoint<TEndPoint>(this IServiceCollection services) where TEndPoint : EndPoint
|
||||
{
|
||||
|
|
@ -167,7 +167,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
{
|
||||
_logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId);
|
||||
|
||||
state.Connection.Metadata["transport"] = TransportType.LongPolling;
|
||||
state.Connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling;
|
||||
|
||||
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
|
||||
}
|
||||
|
|
@ -232,13 +232,10 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
private ConnectionState CreateConnection(HttpContext context)
|
||||
{
|
||||
var state = _manager.CreateConnection();
|
||||
var format = (string)context.Request.Query[ConnectionMetadataNames.Format];
|
||||
state.Connection.User = context.User;
|
||||
|
||||
// TODO: this is wrong. + how does the user add their own metadata based on HttpContext
|
||||
var formatType = (string)context.Request.Query["formatType"];
|
||||
state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType;
|
||||
state.Connection.Metadata[typeof(HttpContext)] = context;
|
||||
|
||||
state.Connection.Metadata[ConnectionMetadataNames.HttpContext] = context;
|
||||
state.Connection.Metadata[ConnectionMetadataNames.Format] = string.IsNullOrEmpty(format) ? "json" : format;
|
||||
return state;
|
||||
}
|
||||
|
||||
|
|
@ -355,7 +352,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
var messages = ParseSendBatch(ref reader, messageFormat);
|
||||
|
||||
// REVIEW: Do we want to return a specific status code here if the connection has ended?
|
||||
_logger.LogDebug("Received batch of {0} message(s) in '/send'", messages.Count);
|
||||
_logger.LogDebug("Received batch of {count} message(s) in '/send'", messages.Count);
|
||||
foreach (var message in messages)
|
||||
{
|
||||
while (!state.Application.Output.TryWrite(message))
|
||||
|
|
@ -379,11 +376,11 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
connectionState.Connection.User = context.User;
|
||||
|
||||
var transport = connectionState.Connection.Metadata.Get<TransportType?>("transport");
|
||||
var transport = connectionState.Connection.Metadata.Get<TransportType?>(ConnectionMetadataNames.Transport);
|
||||
|
||||
if (transport == null)
|
||||
{
|
||||
connectionState.Connection.Metadata["transport"] = transportType;
|
||||
connectionState.Connection.Metadata[ConnectionMetadataNames.Transport] = transportType;
|
||||
}
|
||||
else if (transport != transportType)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
|
|
@ -38,7 +38,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
|
|||
|
||||
public async Task DisposeAsync()
|
||||
{
|
||||
Task disposeTask = TaskCache.CompletedTask;
|
||||
Task disposeTask = Task.CompletedTask;
|
||||
|
||||
try
|
||||
{
|
||||
|
|
@ -69,8 +69,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal
|
|||
Connection.Dispose();
|
||||
Application.Dispose();
|
||||
|
||||
var applicationTask = ApplicationTask ?? TaskCache.CompletedTask;
|
||||
var transportTask = TransportTask ?? TaskCache.CompletedTask;
|
||||
var applicationTask = ApplicationTask ?? Task.CompletedTask;
|
||||
var transportTask = TransportTask ?? Task.CompletedTask;
|
||||
|
||||
disposeTask = WaitOnTasks(applicationTask, transportTask);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@
|
|||
<PackageReference Include="Microsoft.AspNetCore.Hosting.Abstractions" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Routing" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.Extensions.SecurityHelper.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
|
||||
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
|
||||
<PackageReference Include="System.Reflection.TypeExtensions" Version="$(CoreFxVersion)" />
|
||||
<PackageReference Include="System.Security.Claims" Version="$(CoreFxVersion)" />
|
||||
<PackageReference Include="System.Threading.Tasks.Channels" Version="$(CoreFxLabsVersion)" />
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
|
|||
// Is this a frame we care about?
|
||||
if (!frame.Opcode.IsMessage())
|
||||
{
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
LogFrame("Receiving", frame);
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
using System;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.Internal;
|
||||
|
||||
namespace Microsoft.Extensions.WebSockets.Internal
|
||||
{
|
||||
|
|
@ -135,7 +134,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
|
|||
connection.ExecuteAsync((frame, _) =>
|
||||
{
|
||||
messageHandler(frame);
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}, null);
|
||||
|
||||
/// <summary>
|
||||
|
|
@ -149,7 +148,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
|
|||
connection.ExecuteAsync((frame, s) =>
|
||||
{
|
||||
messageHandler(frame, s);
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}, state);
|
||||
|
||||
/// <summary>
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@
|
|||
<PackageReference Include="System.ValueTuple" Version="$(CoreFxVersion)" />
|
||||
<PackageReference Include="System.IO.Pipelines" Version="$(CoreFxLabsVersion)" />
|
||||
<PackageReference Include="System.IO.Pipelines.Text.Primitives" Version="$(CoreFxLabsVersion)" />
|
||||
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Tests.Common
|
||||
|
|
@ -10,36 +11,48 @@ namespace Microsoft.AspNetCore.SignalR.Tests.Common
|
|||
{
|
||||
private const int DefaultTimeout = 5000;
|
||||
|
||||
public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout)
|
||||
public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
|
||||
{
|
||||
return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds));
|
||||
return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber);
|
||||
}
|
||||
|
||||
public static async Task OrTimeout(this Task task, TimeSpan timeout)
|
||||
public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
|
||||
{
|
||||
var completed = await Task.WhenAny(task, Task.Delay(timeout));
|
||||
if (completed != task)
|
||||
{
|
||||
throw new TimeoutException();
|
||||
throw new TimeoutException(GetMessage(memberName, filePath, lineNumber));
|
||||
}
|
||||
|
||||
await task;
|
||||
}
|
||||
|
||||
public static Task<T> OrTimeout<T>(this Task<T> task, int milliseconds = DefaultTimeout)
|
||||
public static Task<T> OrTimeout<T>(this Task<T> task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
|
||||
{
|
||||
return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds));
|
||||
return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber);
|
||||
}
|
||||
|
||||
public static async Task<T> OrTimeout<T>(this Task<T> task, TimeSpan timeout)
|
||||
public static async Task<T> OrTimeout<T>(this Task<T> task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
|
||||
{
|
||||
var completed = await Task.WhenAny(task, Task.Delay(timeout));
|
||||
if (completed != task)
|
||||
{
|
||||
throw new TimeoutException();
|
||||
throw new TimeoutException(GetMessage(memberName, filePath, lineNumber));
|
||||
}
|
||||
|
||||
return await task;
|
||||
}
|
||||
|
||||
private static string GetMessage(string memberName, string filePath, int? lineNumber)
|
||||
{
|
||||
if (!string.IsNullOrEmpty(memberName))
|
||||
{
|
||||
return $"Operation in {memberName} timed out at {filePath}:{lineNumber}";
|
||||
}
|
||||
else
|
||||
{
|
||||
return "Operation timed out";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
|||
|
||||
using (var httpClient = _testServer.CreateClient())
|
||||
{
|
||||
var connection = new HubConnection(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), loggerFactory);
|
||||
var connection = new HubConnection(new Uri("http://test/hubs"), loggerFactory);
|
||||
try
|
||||
{
|
||||
await connection.StartAsync(TransportType.LongPolling, httpClient);
|
||||
|
|
@ -133,13 +133,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
|||
tcs.TrySetResult((string)a[0]);
|
||||
});
|
||||
|
||||
await connection.Invoke<Task>("CallEcho", originalMessage);
|
||||
await connection.Invoke<Task>("CallEcho", originalMessage).OrTimeout();
|
||||
|
||||
Assert.Equal(originalMessage, await tcs.Task.OrTimeout());
|
||||
}
|
||||
finally
|
||||
{
|
||||
await connection.DisposeAsync();
|
||||
await connection.DisposeAsync().OrTimeout();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,156 @@
|
|||
using System;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Newtonsoft.Json;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
||||
{
|
||||
// This includes tests that verify HubConnection conforms to the Hub Protocol, without setting up a full server (even TestServer).
|
||||
// We can also have more control over the messages we send to HubConnection in order to ensure that protocol errors and other quirks
|
||||
// don't cause problems.
|
||||
public class HubConnectionProtocolTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task InvokeSendsAnInvocationMessage()
|
||||
{
|
||||
var connection = new TestConnection();
|
||||
var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
|
||||
try
|
||||
{
|
||||
await hubConnection.StartAsync();
|
||||
|
||||
var invokeTask = hubConnection.Invoke("Foo", typeof(void));
|
||||
|
||||
var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout();
|
||||
|
||||
Assert.Equal("{\"invocationId\":\"1\",\"type\":1,\"target\":\"Foo\",\"arguments\":[]}", invokeMessage);
|
||||
}
|
||||
finally
|
||||
{
|
||||
await hubConnection.DisposeAsync().OrTimeout();
|
||||
await connection.DisposeAsync().OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeCompletedWhenCompletionMessageReceived()
|
||||
{
|
||||
var connection = new TestConnection();
|
||||
var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
|
||||
try
|
||||
{
|
||||
await hubConnection.StartAsync();
|
||||
|
||||
var invokeTask = hubConnection.Invoke("Foo", typeof(void));
|
||||
|
||||
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout();
|
||||
|
||||
await invokeTask.OrTimeout();
|
||||
}
|
||||
finally
|
||||
{
|
||||
await hubConnection.DisposeAsync().OrTimeout();
|
||||
await connection.DisposeAsync().OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeYieldsResultWhenCompletionMessageReceived()
|
||||
{
|
||||
var connection = new TestConnection();
|
||||
var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
|
||||
try
|
||||
{
|
||||
await hubConnection.StartAsync();
|
||||
|
||||
var invokeTask = hubConnection.Invoke<int>("Foo");
|
||||
|
||||
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = 42 }).OrTimeout();
|
||||
|
||||
Assert.Equal(42, await invokeTask.OrTimeout());
|
||||
}
|
||||
finally
|
||||
{
|
||||
await hubConnection.DisposeAsync().OrTimeout();
|
||||
await connection.DisposeAsync().OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeFailsWithExceptionWhenCompletionWithErrorReceived()
|
||||
{
|
||||
var connection = new TestConnection();
|
||||
var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
|
||||
try
|
||||
{
|
||||
await hubConnection.StartAsync();
|
||||
|
||||
var invokeTask = hubConnection.Invoke<int>("Foo");
|
||||
|
||||
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout();
|
||||
|
||||
var ex = await Assert.ThrowsAsync<HubException>(() => invokeTask).OrTimeout();
|
||||
Assert.Equal("An error occurred", ex.Message);
|
||||
}
|
||||
finally
|
||||
{
|
||||
await hubConnection.DisposeAsync().OrTimeout();
|
||||
await connection.DisposeAsync().OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
// This will fail (intentionally) when we support streaming!
|
||||
public async Task InvokeFailsWithErrorWhenStreamingItemReceived()
|
||||
{
|
||||
var connection = new TestConnection();
|
||||
var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
|
||||
try
|
||||
{
|
||||
await hubConnection.StartAsync();
|
||||
|
||||
var invokeTask = hubConnection.Invoke<int>("Foo");
|
||||
|
||||
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, result = 42 }).OrTimeout();
|
||||
|
||||
var ex = await Assert.ThrowsAsync<NotSupportedException>(() => invokeTask).OrTimeout();
|
||||
Assert.Equal("Streaming method results are not supported", ex.Message);
|
||||
}
|
||||
finally
|
||||
{
|
||||
await hubConnection.DisposeAsync().OrTimeout();
|
||||
await connection.DisposeAsync().OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task HandlerRegisteredWithOnIsFiredWhenInvocationReceived()
|
||||
{
|
||||
var connection = new TestConnection();
|
||||
var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
|
||||
var handlerCalled = new TaskCompletionSource<object[]>();
|
||||
try
|
||||
{
|
||||
await hubConnection.StartAsync();
|
||||
|
||||
hubConnection.On("Foo", new[] { typeof(int), typeof(string), typeof(float) }, (a) => handlerCalled.TrySetResult(a));
|
||||
|
||||
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 1, target = "Foo", arguments = new object[] { 1, "Foo", 2.0f } }).OrTimeout();
|
||||
|
||||
var results = await handlerCalled.Task.OrTimeout();
|
||||
Assert.Equal(3, results.Length);
|
||||
Assert.Equal(1, results[0]);
|
||||
Assert.Equal("Foo", results[1]);
|
||||
Assert.Equal(2.0f, results[2]);
|
||||
}
|
||||
finally
|
||||
{
|
||||
await hubConnection.DisposeAsync().OrTimeout();
|
||||
await connection.DisposeAsync().OrTimeout();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2,12 +2,13 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Buffers;
|
||||
using System.Net.Http;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Client.Tests;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.AspNetCore.Sockets.Client;
|
||||
|
|
@ -25,14 +26,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
public void CannotCreateHubConnectionWithNullUrl()
|
||||
{
|
||||
var exception = Assert.Throws<ArgumentNullException>(
|
||||
() => new HubConnection((Uri)null, Mock.Of<IInvocationAdapter>(), Mock.Of<ILoggerFactory>()));
|
||||
() => new HubConnection((Uri)null, Mock.Of<ILoggerFactory>()));
|
||||
Assert.Equal("url", exception.ParamName);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CanDisposeNotStartedHubConnection()
|
||||
{
|
||||
await new HubConnection(new Uri("http://fakeuri.org"), Mock.Of<IInvocationAdapter>(), new LoggerFactory())
|
||||
await new HubConnection(new Uri("http://fakeuri.org"), new LoggerFactory())
|
||||
.DisposeAsync();
|
||||
}
|
||||
|
||||
|
|
@ -50,7 +51,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
|
||||
using (var httpClient = new HttpClient(mockHttpHandler.Object))
|
||||
{
|
||||
var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), Mock.Of<IInvocationAdapter>(), new LoggerFactory());
|
||||
var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), new LoggerFactory());
|
||||
|
||||
try
|
||||
{
|
||||
|
|
@ -81,7 +82,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
|
||||
using (var httpClient = new HttpClient(mockHttpHandler.Object))
|
||||
{
|
||||
var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), Mock.Of<IInvocationAdapter>(), new LoggerFactory());
|
||||
var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), new LoggerFactory());
|
||||
|
||||
await hubConnection.StartAsync(TransportType.LongPolling, httpClient);
|
||||
await hubConnection.DisposeAsync();
|
||||
|
|
@ -110,7 +111,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
|
||||
var exception =
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(async () => await hubConnection.Invoke<int>("test"));
|
||||
Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message);
|
||||
Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
|
@ -119,12 +120,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
var mockConnection = new Mock<IConnection>();
|
||||
|
||||
var exception = new InvalidOperationException();
|
||||
var mockInvocationAdapter = new Mock<IInvocationAdapter>();
|
||||
mockInvocationAdapter
|
||||
.Setup(a => a.WriteMessageAsync(It.IsAny<InvocationMessage>(), It.IsAny<Stream>(), It.IsAny<CancellationToken>()))
|
||||
.Returns(Task.FromException(exception));
|
||||
|
||||
var hubConnection = new HubConnection(mockConnection.Object, mockInvocationAdapter.Object, null);
|
||||
var mockProtocol = MockHubProtocol.Throw(exception);
|
||||
var hubConnection = new HubConnection(mockConnection.Object, mockProtocol, null);
|
||||
await hubConnection.StartAsync(TransportType.All, httpClient: null);
|
||||
|
||||
var actualException =
|
||||
|
|
@ -146,7 +143,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
|
||||
using (var httpClient = new HttpClient(mockHttpHandler.Object))
|
||||
{
|
||||
var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), Mock.Of<IInvocationAdapter>(), new LoggerFactory());
|
||||
var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), new LoggerFactory());
|
||||
try
|
||||
{
|
||||
var connectedEventRaisedTcs = new TaskCompletionSource<object>();
|
||||
|
|
@ -177,7 +174,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
|
||||
using (var httpClient = new HttpClient(mockHttpHandler.Object))
|
||||
{
|
||||
var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), Mock.Of<IInvocationAdapter>(), new LoggerFactory());
|
||||
var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), new LoggerFactory());
|
||||
var closedEventTcs = new TaskCompletionSource<Exception>();
|
||||
hubConnection.Closed += e => closedEventTcs.SetResult(e);
|
||||
|
||||
|
|
@ -197,7 +194,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
.Callback(() => mockConnection.Raise(c => c.Closed += null, (Exception)null))
|
||||
.Returns(Task.FromResult<object>(null));
|
||||
|
||||
var hubConnection = new HubConnection(mockConnection.Object, Mock.Of<IInvocationAdapter>(), new LoggerFactory());
|
||||
var hubConnection = new HubConnection(mockConnection.Object, new LoggerFactory());
|
||||
|
||||
await hubConnection.StartAsync(new TestTransportFactory(Mock.Of<ITransport>()), httpClient: null);
|
||||
await hubConnection.DisposeAsync();
|
||||
|
|
@ -216,7 +213,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
.Callback(() => mockConnection.Raise(c => c.Closed += null, (Exception)null))
|
||||
.Returns(Task.FromResult<object>(null));
|
||||
|
||||
var hubConnection = new HubConnection(mockConnection.Object, Mock.Of<IInvocationAdapter>(), new LoggerFactory());
|
||||
var hubConnection = new HubConnection(mockConnection.Object, new LoggerFactory());
|
||||
|
||||
await hubConnection.StartAsync(new TestTransportFactory(Mock.Of<ITransport>()), httpClient: null);
|
||||
var invokeTask = hubConnection.Invoke("testMethod", typeof(int));
|
||||
|
|
@ -235,7 +232,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
.Callback(() => mockConnection.Raise(c => c.Closed += null, exception))
|
||||
.Returns(Task.FromResult<object>(null));
|
||||
|
||||
var hubConnection = new HubConnection(mockConnection.Object, Mock.Of<IInvocationAdapter>(), new LoggerFactory());
|
||||
var hubConnection = new HubConnection(mockConnection.Object, new LoggerFactory());
|
||||
|
||||
await hubConnection.StartAsync(new TestTransportFactory(Mock.Of<ITransport>()), httpClient: null);
|
||||
var invokeTask = hubConnection.Invoke("testMethod", typeof(int));
|
||||
|
|
@ -250,23 +247,69 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
{
|
||||
var mockConnection = new Mock<IConnection>();
|
||||
|
||||
var invocationDescriptor = new InvocationDescriptor
|
||||
{
|
||||
Method = "NonExistingMethod123",
|
||||
Arguments = new object[] { true, "arg2", 123 },
|
||||
Id = Guid.NewGuid().ToString()
|
||||
};
|
||||
var invocation = new InvocationMessage(Guid.NewGuid().ToString(), nonBlocking: true, target: "NonExistingMethod123", arguments: new object[] { true, "arg2", 123 });
|
||||
|
||||
var mockInvocationAdapter = new Mock<IInvocationAdapter>();
|
||||
mockInvocationAdapter
|
||||
.Setup(a => a.ReadMessageAsync(It.IsAny<Stream>(), It.IsAny<IInvocationBinder>(), It.IsAny<CancellationToken>()))
|
||||
.Returns(Task.FromResult((InvocationMessage)invocationDescriptor));
|
||||
var mockProtocol = MockHubProtocol.ReturnOnParse(invocation);
|
||||
|
||||
var hubConnection = new HubConnection(mockConnection.Object, mockInvocationAdapter.Object, null);
|
||||
var hubConnection = new HubConnection(mockConnection.Object, mockProtocol, null);
|
||||
await hubConnection.StartAsync(new TestTransportFactory(Mock.Of<ITransport>()), httpClient: null);
|
||||
|
||||
mockConnection.Raise(c => c.Received += null, new object[] { new byte[] { }, MessageType.Text });
|
||||
mockInvocationAdapter.Verify(a => a.ReadMessageAsync(It.IsAny<Stream>(), It.IsAny<IInvocationBinder>(), It.IsAny<CancellationToken>()), Times.Once());
|
||||
Assert.Equal(1, mockProtocol.ParseCalls);
|
||||
}
|
||||
|
||||
// Moq really doesn't handle out parameters well, so to make these tests work I added a manual mock -anurse
|
||||
private class MockHubProtocol : IHubProtocol
|
||||
{
|
||||
private HubMessage _parsed;
|
||||
private Exception _error;
|
||||
|
||||
public int ParseCalls { get; private set; } = 0;
|
||||
public int WriteCalls { get; private set; } = 0;
|
||||
|
||||
public MessageType MessageType => MessageType.Text;
|
||||
|
||||
public static MockHubProtocol ReturnOnParse(HubMessage parsed)
|
||||
{
|
||||
return new MockHubProtocol
|
||||
{
|
||||
_parsed = parsed
|
||||
};
|
||||
}
|
||||
|
||||
public static MockHubProtocol Throw(Exception error)
|
||||
{
|
||||
return new MockHubProtocol
|
||||
{
|
||||
_error = error
|
||||
};
|
||||
}
|
||||
|
||||
public HubMessage ParseMessage(ReadOnlySpan<byte> input, IInvocationBinder binder)
|
||||
{
|
||||
ParseCalls += 1;
|
||||
if (_error != null)
|
||||
{
|
||||
throw _error;
|
||||
}
|
||||
if (_parsed != null)
|
||||
{
|
||||
return _parsed;
|
||||
}
|
||||
|
||||
throw new InvalidOperationException("No Parsed Message provided");
|
||||
}
|
||||
|
||||
public bool TryWriteMessage(HubMessage message, IOutput output)
|
||||
{
|
||||
WriteCalls += 1;
|
||||
|
||||
if (_error != null)
|
||||
{
|
||||
throw _error;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,116 @@
|
|||
using System;
|
||||
using System.Net.Http;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.AspNetCore.Sockets.Client;
|
||||
using Newtonsoft.Json;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
||||
{
|
||||
internal class TestConnection : IConnection
|
||||
{
|
||||
private TaskCompletionSource<object> _started = new TaskCompletionSource<object>();
|
||||
private TaskCompletionSource<object> _disposed = new TaskCompletionSource<object>();
|
||||
|
||||
private Channel<Message> _sentMessages = Channel.CreateUnbounded<Message>();
|
||||
private Channel<Message> _receivedMessages = Channel.CreateUnbounded<Message>();
|
||||
|
||||
private CancellationTokenSource _receiveShutdownToken = new CancellationTokenSource();
|
||||
private Task _receiveLoop;
|
||||
|
||||
public event Action Connected;
|
||||
public event Action<byte[], MessageType> Received;
|
||||
public event Action<Exception> Closed;
|
||||
|
||||
public Task Started => _started.Task;
|
||||
public Task Disposed => _disposed.Task;
|
||||
public ReadableChannel<Message> SentMessages => _sentMessages.In;
|
||||
public WritableChannel<Message> ReceivedMessages => _receivedMessages.Out;
|
||||
|
||||
public TestConnection()
|
||||
{
|
||||
_receiveLoop = ReceiveLoopAsync(_receiveShutdownToken.Token);
|
||||
}
|
||||
|
||||
public Task DisposeAsync()
|
||||
{
|
||||
_disposed.TrySetResult(null);
|
||||
_receiveShutdownToken.Cancel();
|
||||
return _receiveLoop;
|
||||
}
|
||||
|
||||
public async Task SendAsync(byte[] data, MessageType type, CancellationToken cancellationToken)
|
||||
{
|
||||
if(!_started.Task.IsCompleted)
|
||||
{
|
||||
throw new InvalidOperationException("Connection must be started before SendAsync can be called");
|
||||
}
|
||||
|
||||
var message = new Message(data, type, endOfMessage: true);
|
||||
while (await _sentMessages.Out.WaitToWriteAsync(cancellationToken))
|
||||
{
|
||||
if (_sentMessages.Out.TryWrite(message))
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
throw new ObjectDisposedException("Unable to send message, underlying channel was closed");
|
||||
}
|
||||
|
||||
public Task StartAsync(ITransportFactory transportFactory, HttpClient httpClient)
|
||||
{
|
||||
_started.TrySetResult(null);
|
||||
Connected?.Invoke();
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public async Task<string> ReadSentTextMessageAsync()
|
||||
{
|
||||
var message = await SentMessages.ReadAsync();
|
||||
if (message.Type != MessageType.Text)
|
||||
{
|
||||
throw new InvalidOperationException($"Unexpected message of type: {message.Type}");
|
||||
}
|
||||
return Encoding.UTF8.GetString(message.Payload);
|
||||
}
|
||||
|
||||
public Task ReceiveJsonMessage(object jsonObject)
|
||||
{
|
||||
var json = JsonConvert.SerializeObject(jsonObject, Formatting.None);
|
||||
var bytes = Encoding.UTF8.GetBytes(json);
|
||||
var message = new Message(bytes, MessageType.Text);
|
||||
|
||||
return _receivedMessages.Out.WriteAsync(message);
|
||||
}
|
||||
|
||||
private async Task ReceiveLoopAsync(CancellationToken token)
|
||||
{
|
||||
try
|
||||
{
|
||||
while (!token.IsCancellationRequested)
|
||||
{
|
||||
while (await _receivedMessages.In.WaitToReadAsync(token))
|
||||
{
|
||||
while (_receivedMessages.In.TryRead(out var message))
|
||||
{
|
||||
Received?.Invoke(message.Payload, message.Type);
|
||||
}
|
||||
}
|
||||
}
|
||||
Closed?.Invoke(null);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Do nothing, we were just asked to shut down.
|
||||
Closed?.Invoke(null);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Closed?.Invoke(ex);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,257 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Newtonsoft.Json;
|
||||
using Newtonsoft.Json.Serialization;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
|
||||
{
|
||||
public class JsonHubProtocolTests
|
||||
{
|
||||
public static IEnumerable<object[]> ProtocolTestData => new[]
|
||||
{
|
||||
new object[] { new InvocationMessage("123", true, "Target", 1, "Foo", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"nonBlocking\":true,\"arguments\":[1,\"Foo\",2.0]}" },
|
||||
new object[] { new InvocationMessage("123", false, "Target", 1, "Foo", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[1,\"Foo\",2.0]}" },
|
||||
new object[] { new InvocationMessage("123", false, "Target", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[true]}" },
|
||||
new object[] { new InvocationMessage("123", false, "Target", new object[] { null }), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[null]}" },
|
||||
new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}]}" },
|
||||
new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}]}" },
|
||||
new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}]}" },
|
||||
new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}]}" },
|
||||
|
||||
new object[] { new StreamItemMessage("123", 1), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":1}" },
|
||||
new object[] { new StreamItemMessage("123", "Foo"), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":\"Foo\"}" },
|
||||
new object[] { new StreamItemMessage("123", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":2.0}" },
|
||||
new object[] { new StreamItemMessage("123", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":true}" },
|
||||
new object[] { new StreamItemMessage("123", null), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":null}" },
|
||||
new object[] { new StreamItemMessage("123", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}}" },
|
||||
new object[] { new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}}" },
|
||||
new object[] { new StreamItemMessage("123", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}}" },
|
||||
new object[] { new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}}" },
|
||||
|
||||
new object[] { CompletionMessage.WithResult("123", 1), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":1}" },
|
||||
new object[] { CompletionMessage.WithResult("123", "Foo"), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":\"Foo\"}" },
|
||||
new object[] { CompletionMessage.WithResult("123", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":2.0}" },
|
||||
new object[] { CompletionMessage.WithResult("123", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":true}" },
|
||||
new object[] { CompletionMessage.WithResult("123", null), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":null}" },
|
||||
new object[] { CompletionMessage.WithError("123", "Whoops!"), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"error\":\"Whoops!\"}" },
|
||||
new object[] { CompletionMessage.WithResult("123", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}}" },
|
||||
new object[] { CompletionMessage.WithResult("123", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}}" },
|
||||
new object[] { CompletionMessage.WithResult("123", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}}" },
|
||||
new object[] { CompletionMessage.WithResult("123", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}}" },
|
||||
};
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(ProtocolTestData))]
|
||||
public async Task WriteMessage(HubMessage message, bool camelCase, NullValueHandling nullValueHandling, string expectedOutput)
|
||||
{
|
||||
var jsonSerializer = new JsonSerializer
|
||||
{
|
||||
NullValueHandling = nullValueHandling,
|
||||
ContractResolver = camelCase ? new CamelCasePropertyNamesContractResolver() : new DefaultContractResolver()
|
||||
};
|
||||
|
||||
var protocol = new JsonHubProtocol(jsonSerializer);
|
||||
var encoded = await protocol.WriteToArrayAsync(message);
|
||||
var json = Encoding.UTF8.GetString(encoded);
|
||||
|
||||
Assert.Equal(expectedOutput, json);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(ProtocolTestData))]
|
||||
public void ParseMessage(HubMessage expectedMessage, bool camelCase, NullValueHandling nullValueHandling, string input)
|
||||
{
|
||||
var jsonSerializer = new JsonSerializer
|
||||
{
|
||||
NullValueHandling = nullValueHandling,
|
||||
ContractResolver = camelCase ? new CamelCasePropertyNamesContractResolver() : new DefaultContractResolver()
|
||||
};
|
||||
|
||||
var binder = new TestBinder(expectedMessage);
|
||||
var protocol = new JsonHubProtocol(jsonSerializer);
|
||||
var message = protocol.ParseMessage(Encoding.UTF8.GetBytes(input), binder);
|
||||
|
||||
Assert.Equal(expectedMessage, message, TestEqualityComparer.Instance);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData("", "Error reading JSON.")]
|
||||
[InlineData("null", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")]
|
||||
[InlineData("42", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")]
|
||||
[InlineData("'foo'", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")]
|
||||
[InlineData("[42]", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")]
|
||||
[InlineData("{}", "Missing required property 'type'.")]
|
||||
|
||||
[InlineData("{'type':1}", "Missing required property 'invocationId'.")]
|
||||
[InlineData("{'type':1,'invocationId':42}", "Expected 'invocationId' to be of type String.")]
|
||||
[InlineData("{'type':1,'invocationId':'42','target':42}", "Expected 'target' to be of type String.")]
|
||||
[InlineData("{'type':1,'invocationId':'42','target':'foo'}", "Missing required property 'arguments'.")]
|
||||
[InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':{}}", "Expected 'arguments' to be of type Array.")]
|
||||
|
||||
[InlineData("{'type':2}", "Missing required property 'invocationId'.")]
|
||||
[InlineData("{'type':2,'invocationId':42}", "Expected 'invocationId' to be of type String.")]
|
||||
[InlineData("{'type':2,'invocationId':'42'}", "Missing required property 'result'.")]
|
||||
|
||||
[InlineData("{'type':3}", "Missing required property 'invocationId'.")]
|
||||
[InlineData("{'type':3,'invocationId':42}", "Expected 'invocationId' to be of type String.")]
|
||||
[InlineData("{'type':3,'invocationId':'42','error':[]}", "Expected 'error' to be of type String.")]
|
||||
|
||||
[InlineData("{'type':4}", "Unknown message type: 4")]
|
||||
[InlineData("{'type':'foo'}", "Expected 'type' to be of type Integer.")]
|
||||
public void InvalidMessages(string input, string expectedMessage)
|
||||
{
|
||||
var binder = new TestBinder();
|
||||
var protocol = new JsonHubProtocol(new JsonSerializer());
|
||||
var ex = Assert.Throws<FormatException>(() => protocol.ParseMessage(Encoding.UTF8.GetBytes(input), binder));
|
||||
Assert.Equal(expectedMessage, ex.Message);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':[]}", "Invocation provides 0 argument(s) but target expects 2.")]
|
||||
[InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':[42, 'foo'],'nonBlocking':42}", "Expected 'nonBlocking' to be of type Boolean.")]
|
||||
[InlineData("{'type':3,'invocationId':'42','error':'foo','result':true}", "The 'error' and 'result' properties are mutually exclusive.")]
|
||||
public void InvalidMessagesWithBinder(string input, string expectedMessage)
|
||||
{
|
||||
var binder = new TestBinder(paramTypes: new[] { typeof(int), typeof(string) }, returnType: typeof(bool));
|
||||
var protocol = new JsonHubProtocol(new JsonSerializer());
|
||||
var ex = Assert.Throws<FormatException>(() => protocol.ParseMessage(Encoding.UTF8.GetBytes(input), binder));
|
||||
Assert.Equal(expectedMessage, ex.Message);
|
||||
}
|
||||
|
||||
private class CustomObject : IEquatable<CustomObject>
|
||||
{
|
||||
// Not intended to be a full set of things, just a smattering of sample serializations
|
||||
public string StringProp => "SignalR!";
|
||||
|
||||
public double DoubleProp => 6.2831853071;
|
||||
|
||||
public int IntProp => 42;
|
||||
|
||||
public DateTime DateTimeProp => new DateTime(2017, 4, 11);
|
||||
|
||||
public object NullProp => null;
|
||||
|
||||
public override bool Equals(object obj)
|
||||
{
|
||||
return obj is CustomObject o && Equals(o);
|
||||
}
|
||||
|
||||
public override int GetHashCode()
|
||||
{
|
||||
// This is never used in a hash table
|
||||
return 0;
|
||||
}
|
||||
|
||||
public bool Equals(CustomObject right)
|
||||
{
|
||||
// This allows the comparer below to properly compare the object in the test.
|
||||
return string.Equals(StringProp, right.StringProp, StringComparison.Ordinal) &&
|
||||
DoubleProp == right.DoubleProp &&
|
||||
IntProp == right.IntProp &&
|
||||
DateTime.Equals(DateTimeProp, right.DateTimeProp) &&
|
||||
NullProp == right.NullProp;
|
||||
}
|
||||
}
|
||||
|
||||
// Binder that works based on the expected message argument/result types :)
|
||||
private class TestBinder : IInvocationBinder
|
||||
{
|
||||
private readonly Type[] _paramTypes;
|
||||
private readonly Type _returnType;
|
||||
|
||||
public TestBinder(HubMessage expectedMessage)
|
||||
{
|
||||
switch(expectedMessage)
|
||||
{
|
||||
case InvocationMessage i:
|
||||
_paramTypes = i.Arguments?.Select(a => a?.GetType() ?? typeof(object))?.ToArray();
|
||||
break;
|
||||
case StreamItemMessage s:
|
||||
_returnType = s.Item?.GetType() ?? typeof(object);
|
||||
break;
|
||||
case CompletionMessage c:
|
||||
_returnType = c.Result?.GetType() ?? typeof(object);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
public TestBinder() : this(null, null) { }
|
||||
public TestBinder(Type[] paramTypes) : this(paramTypes, null) { }
|
||||
public TestBinder(Type returnType) : this(null, returnType) {}
|
||||
public TestBinder(Type[] paramTypes, Type returnType)
|
||||
{
|
||||
_paramTypes = paramTypes;
|
||||
_returnType = returnType;
|
||||
}
|
||||
|
||||
public Type[] GetParameterTypes(string methodName)
|
||||
{
|
||||
if (_paramTypes != null)
|
||||
{
|
||||
return _paramTypes;
|
||||
}
|
||||
throw new InvalidOperationException("Unexpected binder call");
|
||||
}
|
||||
|
||||
public Type GetReturnType(string invocationId)
|
||||
{
|
||||
if (_returnType != null)
|
||||
{
|
||||
return _returnType;
|
||||
}
|
||||
throw new InvalidOperationException("Unexpected binder call");
|
||||
}
|
||||
}
|
||||
|
||||
private class TestEqualityComparer : IEqualityComparer<HubMessage>
|
||||
{
|
||||
public static readonly TestEqualityComparer Instance = new TestEqualityComparer();
|
||||
|
||||
private TestEqualityComparer() { }
|
||||
|
||||
public bool Equals(HubMessage x, HubMessage y)
|
||||
{
|
||||
if (!string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return InvocationMessagesEqual(x, y) || StreamItemMessagesEqual(x, y) || CompletionMessagesEqual(x, y);
|
||||
}
|
||||
|
||||
private bool CompletionMessagesEqual(HubMessage x, HubMessage y)
|
||||
{
|
||||
return x is CompletionMessage left && y is CompletionMessage right &&
|
||||
string.Equals(left.Error, right.Error, StringComparison.Ordinal) &&
|
||||
Equals(left.Result, right.Result) &&
|
||||
left.HasResult == right.HasResult;
|
||||
}
|
||||
|
||||
private bool StreamItemMessagesEqual(HubMessage x, HubMessage y)
|
||||
{
|
||||
return x is StreamItemMessage left && y is StreamItemMessage right &&
|
||||
Equals(left.Item, right.Item);
|
||||
}
|
||||
|
||||
private bool InvocationMessagesEqual(HubMessage x, HubMessage y)
|
||||
{
|
||||
return x is InvocationMessage left && y is InvocationMessage right &&
|
||||
string.Equals(left.Target, right.Target, StringComparison.Ordinal) &&
|
||||
Enumerable.SequenceEqual(left.Arguments, right.Arguments) &&
|
||||
left.NonBlocking == right.NonBlocking;
|
||||
}
|
||||
|
||||
public int GetHashCode(HubMessage obj)
|
||||
{
|
||||
// We never use these in a hash-table
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<Import Project="..\..\build\common.props" />
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFrameworks>netcoreapp2.0;net46</TargetFrameworks>
|
||||
<TargetFrameworks Condition="'$(OS)' != 'Windows_NT'">netcoreapp2.0</TargetFrameworks>
|
||||
<!-- TODO remove when https://github.com/Microsoft/vstest/issues/428 is resolved -->
|
||||
<AutoGenerateBindingRedirects>true</AutoGenerateBindingRedirects>
|
||||
<GenerateBindingRedirectsOutputType>true</GenerateBindingRedirectsOutputType>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Compile Remove="Protocol\**" />
|
||||
<EmbeddedResource Remove="Protocol\**" />
|
||||
<None Remove="Protocol\**" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(TestSdkVersion)" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="$(XunitVersion)" />
|
||||
<PackageReference Include="xunit" Version="$(XunitVersion)" />
|
||||
<PackageReference Include="System.ValueTuple" Version="$(CoreFxVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Service Include="{82a7f48d-3b50-4b1e-b82e-3ada8210c358}" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Internal;
|
||||
using Moq;
|
||||
using Xunit;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||
{
|
||||
|
|
@ -19,7 +19,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var trackDispose = new TrackDispose();
|
||||
var serviceProvider = CreateServiceProvider(s => s.AddSingleton(trackDispose));
|
||||
var endPoint = serviceProvider.GetService<HubEndPoint<TestHub>>();
|
||||
var endPoint = serviceProvider.GetService<HubEndPoint<DisposeTrackingHub>>();
|
||||
|
||||
using (var client = new TestClient(serviceProvider))
|
||||
{
|
||||
|
|
@ -127,10 +127,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
var result = await client.Invoke<InvocationResultDescriptor>(nameof(MethodHub.TaskValueMethod)).OrTimeout();
|
||||
var result = (await client.InvokeAsync(nameof(MethodHub.TaskValueMethod)).OrTimeout()).Result;
|
||||
|
||||
// json serializer makes this a long
|
||||
Assert.Equal(42L, result.Result);
|
||||
Assert.Equal(42L, result);
|
||||
|
||||
// kill the connection
|
||||
client.Dispose();
|
||||
|
|
@ -150,10 +150,33 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
var result = await client.Invoke<InvocationResultDescriptor>("echo", "hello").OrTimeout();
|
||||
var result = (await client.InvokeAsync("echo", "hello").OrTimeout()).Result;
|
||||
|
||||
Assert.Null(result.Error);
|
||||
Assert.Equal("hello", result.Result);
|
||||
Assert.Equal("hello", result);
|
||||
|
||||
// kill the connection
|
||||
client.Dispose();
|
||||
|
||||
await endPointTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(nameof(MethodHub.MethodThatThrows))]
|
||||
[InlineData(nameof(MethodHub.MethodThatYieldsFailedTask))]
|
||||
public async Task HubMethodCanThrowOrYieldFailedTask(string methodName)
|
||||
{
|
||||
var serviceProvider = CreateServiceProvider();
|
||||
|
||||
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
|
||||
|
||||
using (var client = new TestClient(serviceProvider))
|
||||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
var result = (await client.InvokeAsync(methodName).OrTimeout());
|
||||
|
||||
Assert.Equal("BOOM!", result.Error);
|
||||
|
||||
// kill the connection
|
||||
client.Dispose();
|
||||
|
|
@ -173,10 +196,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
var result = await client.Invoke<InvocationResultDescriptor>(nameof(MethodHub.ValueMethod)).OrTimeout();
|
||||
var result = (await client.InvokeAsync(nameof(MethodHub.ValueMethod)).OrTimeout()).Result;
|
||||
|
||||
// json serializer makes this a long
|
||||
Assert.Equal(43L, result.Result);
|
||||
Assert.Equal(43L, result);
|
||||
|
||||
// kill the connection
|
||||
client.Dispose();
|
||||
|
|
@ -196,9 +219,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
var result = await client.Invoke<InvocationResultDescriptor>(nameof(MethodHub.VoidMethod)).OrTimeout();
|
||||
var result = (await client.InvokeAsync(nameof(MethodHub.VoidMethod)).OrTimeout()).Result;
|
||||
|
||||
Assert.Null(result.Result);
|
||||
Assert.Null(result);
|
||||
|
||||
// kill the connection
|
||||
client.Dispose();
|
||||
|
|
@ -218,9 +241,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
var result = await client.Invoke<InvocationResultDescriptor>(nameof(MethodHub.ConcatString), (byte)32, 42, 'm', "string").OrTimeout();
|
||||
var result = (await client.InvokeAsync(nameof(MethodHub.ConcatString), (byte)32, 42, 'm', "string").OrTimeout()).Result;
|
||||
|
||||
Assert.Equal("32, 42, m, string", result.Result);
|
||||
Assert.Equal("32, 42, m, string", result);
|
||||
|
||||
// kill the connection
|
||||
client.Dispose();
|
||||
|
|
@ -240,9 +263,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
var result = await client.Invoke<InvocationResultDescriptor>(nameof(InheritedHub.BaseMethod), "string").OrTimeout();
|
||||
var result = (await client.InvokeAsync(nameof(InheritedHub.BaseMethod), "string").OrTimeout()).Result;
|
||||
|
||||
Assert.Equal("string", result.Result);
|
||||
Assert.Equal("string", result);
|
||||
|
||||
// kill the connection
|
||||
client.Dispose();
|
||||
|
|
@ -262,9 +285,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
var result = await client.Invoke<InvocationResultDescriptor>(nameof(InheritedHub.VirtualMethod), 10).OrTimeout();
|
||||
var result = (await client.InvokeAsync(nameof(InheritedHub.VirtualMethod), 10).OrTimeout()).Result;
|
||||
|
||||
Assert.Equal(0L, result.Result);
|
||||
Assert.Equal(0L, result);
|
||||
|
||||
// kill the connection
|
||||
client.Dispose();
|
||||
|
|
@ -284,7 +307,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
var result = await client.Invoke<InvocationResultDescriptor>(nameof(MethodHub.OnDisconnectedAsync)).OrTimeout();
|
||||
var result = await client.InvokeAsync(nameof(MethodHub.OnDisconnectedAsync)).OrTimeout();
|
||||
|
||||
Assert.Equal("Unknown hub method 'OnDisconnectedAsync'", result.Error);
|
||||
|
||||
|
|
@ -326,15 +349,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
|
||||
|
||||
await firstClient.Invoke(nameof(MethodHub.BroadcastMethod), "test").OrTimeout();
|
||||
await firstClient.SendInvocationAsync(nameof(MethodHub.BroadcastMethod), "test").OrTimeout();
|
||||
|
||||
foreach (var result in await Task.WhenAll(
|
||||
firstClient.Read<InvocationDescriptor>(),
|
||||
secondClient.Read<InvocationDescriptor>()).OrTimeout())
|
||||
firstClient.Read(),
|
||||
secondClient.Read()).OrTimeout())
|
||||
{
|
||||
Assert.Equal("Broadcast", result.Method);
|
||||
Assert.Equal(1, result.Arguments.Length);
|
||||
Assert.Equal("test", result.Arguments[0]);
|
||||
var invocation = Assert.IsType<InvocationMessage>(result);
|
||||
Assert.Equal("Broadcast", invocation.Target);
|
||||
Assert.Equal(1, invocation.Arguments.Length);
|
||||
Assert.Equal("test", invocation.Arguments[0]);
|
||||
}
|
||||
|
||||
// kill the connections
|
||||
|
|
@ -360,23 +384,24 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
|
||||
|
||||
var result = await firstClient.Invoke<InvocationResultDescriptor>(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout();
|
||||
var result = (await firstClient.InvokeAsync(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout()).Result;
|
||||
|
||||
// check that 'firstConnection' hasn't received the group send
|
||||
Assert.Null(result.Id);
|
||||
Assert.Null(firstClient.TryRead());
|
||||
|
||||
// check that 'secondConnection' hasn't received the group send
|
||||
Assert.Null(await secondClient.TryRead<InvocationDescriptor>().OrTimeout());
|
||||
Assert.Null(secondClient.TryRead());
|
||||
|
||||
result = await secondClient.Invoke<InvocationResultDescriptor>(nameof(MethodHub.GroupAddMethod), "testGroup").OrTimeout();
|
||||
Assert.Null(result.Id);
|
||||
result = (await secondClient.InvokeAsync(nameof(MethodHub.GroupAddMethod), "testGroup").OrTimeout()).Result;
|
||||
|
||||
await firstClient.Invoke(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout();
|
||||
await firstClient.SendInvocationAsync(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout();
|
||||
|
||||
// check that 'secondConnection' has received the group send
|
||||
var descriptor = await secondClient.Read<InvocationDescriptor>().OrTimeout();
|
||||
Assert.Equal("Send", descriptor.Method);
|
||||
Assert.Equal(1, descriptor.Arguments.Length);
|
||||
Assert.Equal("test", descriptor.Arguments[0]);
|
||||
var hubMessage = await secondClient.Read().OrTimeout();
|
||||
var invocation = Assert.IsType<InvocationMessage>(hubMessage);
|
||||
Assert.Equal("Send", invocation.Target);
|
||||
Assert.Equal(1, invocation.Arguments.Length);
|
||||
Assert.Equal("test", invocation.Arguments[0]);
|
||||
|
||||
// kill the connections
|
||||
firstClient.Dispose();
|
||||
|
|
@ -397,7 +422,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
await client.Invoke(nameof(MethodHub.GroupRemoveMethod), "testGroup").OrTimeout();
|
||||
await client.SendInvocationAsync(nameof(MethodHub.GroupRemoveMethod), "testGroup").OrTimeout();
|
||||
|
||||
// kill the connection
|
||||
client.Dispose();
|
||||
|
|
@ -421,13 +446,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
|
||||
|
||||
await firstClient.Invoke(nameof(MethodHub.ClientSendMethod), secondClient.Connection.User.Identity.Name, "test").OrTimeout();
|
||||
await firstClient.SendInvocationAsync(nameof(MethodHub.ClientSendMethod), secondClient.Connection.User.Identity.Name, "test").OrTimeout();
|
||||
|
||||
// check that 'secondConnection' has received the group send
|
||||
var result = await secondClient.Read<InvocationDescriptor>().OrTimeout();
|
||||
Assert.Equal("Send", result.Method);
|
||||
Assert.Equal(1, result.Arguments.Length);
|
||||
Assert.Equal("test", result.Arguments[0]);
|
||||
var hubMessage = await secondClient.Read().OrTimeout();
|
||||
var invocation = Assert.IsType<InvocationMessage>(hubMessage);
|
||||
Assert.Equal("Send", invocation.Target);
|
||||
Assert.Equal(1, invocation.Arguments.Length);
|
||||
Assert.Equal("test", invocation.Arguments[0]);
|
||||
|
||||
// kill the connections
|
||||
firstClient.Dispose();
|
||||
|
|
@ -452,13 +478,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
|
||||
|
||||
await firstClient.Invoke(nameof(MethodHub.ConnectionSendMethod), secondClient.Connection.ConnectionId, "test").OrTimeout();
|
||||
await firstClient.SendInvocationAsync(nameof(MethodHub.ConnectionSendMethod), secondClient.Connection.ConnectionId, "test").OrTimeout();
|
||||
|
||||
// check that 'secondConnection' has received the group send
|
||||
var result = await secondClient.Read<InvocationDescriptor>().OrTimeout();
|
||||
Assert.Equal("Send", result.Method);
|
||||
Assert.Equal(1, result.Arguments.Length);
|
||||
Assert.Equal("test", result.Arguments[0]);
|
||||
var hubMessage = await secondClient.Read().OrTimeout();
|
||||
var invocation = Assert.IsType<InvocationMessage>(hubMessage);
|
||||
Assert.Equal("Send", invocation.Target);
|
||||
Assert.Equal(1, invocation.Arguments.Length);
|
||||
Assert.Equal("test", invocation.Arguments[0]);
|
||||
|
||||
// kill the connections
|
||||
firstClient.Dispose();
|
||||
|
|
@ -575,7 +602,17 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
public override Task OnDisconnectedAsync(Exception e)
|
||||
{
|
||||
return TaskCache.CompletedTask;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public void MethodThatThrows()
|
||||
{
|
||||
throw new InvalidOperationException("BOOM!");
|
||||
}
|
||||
|
||||
public Task MethodThatYieldsFailedTask()
|
||||
{
|
||||
return Task.FromException(new InvalidOperationException("BOOM!"));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -611,11 +648,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
}
|
||||
|
||||
private class TestHub : Hub
|
||||
private class DisposeTrackingHub : Hub
|
||||
{
|
||||
private TrackDispose _trackDispose;
|
||||
|
||||
public TestHub(TrackDispose trackDispose)
|
||||
public DisposeTrackingHub(TrackDispose trackDispose)
|
||||
{
|
||||
_trackDispose = trackDispose;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,29 +2,30 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Collections.Generic;
|
||||
using System.Security.Claims;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.AspNetCore.Sockets.Internal;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Newtonsoft.Json;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||
{
|
||||
public class TestClient : IDisposable
|
||||
public class TestClient : IDisposable, IInvocationBinder
|
||||
{
|
||||
private static int _id;
|
||||
private IInvocationAdapter _adapter;
|
||||
private IHubProtocol _protocol;
|
||||
private CancellationTokenSource _cts;
|
||||
private TestBinder _binder;
|
||||
|
||||
public Connection Connection;
|
||||
public IChannelConnection<Message> Application { get; }
|
||||
public Task Connected => Connection.Metadata.Get<TaskCompletionSource<bool>>("ConnectedTask").Task;
|
||||
|
||||
public TestClient(IServiceProvider serviceProvider, string format = "json")
|
||||
public TestClient(IServiceProvider serviceProvider)
|
||||
{
|
||||
var transportToApplication = Channel.CreateUnbounded<Message>();
|
||||
var applicationToTransport = Channel.CreateUnbounded<Message>();
|
||||
|
|
@ -33,62 +34,80 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
var transport = ChannelConnection.Create<Message>(input: transportToApplication, output: applicationToTransport);
|
||||
|
||||
Connection = new Connection(Guid.NewGuid().ToString(), transport);
|
||||
Connection.Metadata["formatType"] = format;
|
||||
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));
|
||||
Connection.Metadata["ConnectedTask"] = new TaskCompletionSource<bool>();
|
||||
|
||||
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
|
||||
_adapter = invocationAdapter.GetInvocationAdapter(format);
|
||||
|
||||
_binder = new TestBinder();
|
||||
_protocol = new JsonHubProtocol(new JsonSerializer());
|
||||
|
||||
_cts = new CancellationTokenSource();
|
||||
}
|
||||
|
||||
public async Task<T> Invoke<T>(string methodName, params object[] args) where T : InvocationMessage
|
||||
public async Task<CompletionMessage> InvokeAsync(string methodName, params object[] args)
|
||||
{
|
||||
await Invoke(methodName, args);
|
||||
var invocationId = await SendInvocationAsync(methodName, args);
|
||||
|
||||
return await Read<T>();
|
||||
}
|
||||
|
||||
public async Task Invoke(string methodName, params object[] args)
|
||||
{
|
||||
var stream = new MemoryStream();
|
||||
await _adapter.WriteMessageAsync(new InvocationDescriptor
|
||||
while (true)
|
||||
{
|
||||
Arguments = args,
|
||||
Method = methodName
|
||||
},
|
||||
stream);
|
||||
var message = await Read();
|
||||
|
||||
await Application.Output.WriteAsync(new Message(stream.ToArray(), MessageType.Binary, endOfMessage: true));
|
||||
}
|
||||
|
||||
public async Task<T> Read<T>() where T : InvocationMessage
|
||||
{
|
||||
while (await Application.Input.WaitToReadAsync(_cts.Token))
|
||||
{
|
||||
var value = await TryRead<T>();
|
||||
|
||||
if (value != null)
|
||||
if (!string.Equals(message.InvocationId, invocationId))
|
||||
{
|
||||
return value;
|
||||
throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
|
||||
}
|
||||
|
||||
if (message == null)
|
||||
{
|
||||
throw new InvalidOperationException("Connection aborted!");
|
||||
}
|
||||
|
||||
switch (message)
|
||||
{
|
||||
case StreamItemMessage result:
|
||||
throw new NotSupportedException("TestClient does not support streaming!");
|
||||
case CompletionMessage completion:
|
||||
return completion;
|
||||
default:
|
||||
throw new NotSupportedException("TestClient does not support receiving invocations!");
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
public async Task<T> TryRead<T>() where T : InvocationMessage
|
||||
public async Task<string> SendInvocationAsync(string methodName, params object[] args)
|
||||
{
|
||||
Message message;
|
||||
if (Application.Input.TryRead(out message))
|
||||
{
|
||||
var value = await _adapter.ReadMessageAsync(new MemoryStream(message.Payload), _binder);
|
||||
return value as T;
|
||||
}
|
||||
var invocationId = GetInvocationId();
|
||||
var payload = await _protocol.WriteToArrayAsync(new InvocationMessage(invocationId, nonBlocking: false, target: methodName, arguments: args));
|
||||
|
||||
await Application.Output.WriteAsync(new Message(payload, _protocol.MessageType, endOfMessage: true));
|
||||
|
||||
return invocationId;
|
||||
}
|
||||
|
||||
public async Task<HubMessage> Read()
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
var message = TryRead();
|
||||
|
||||
if (message == null)
|
||||
{
|
||||
if (!await Application.Input.WaitToReadAsync())
|
||||
{
|
||||
return null;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return message;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public HubMessage TryRead()
|
||||
{
|
||||
if (Application.Input.TryRead(out var message))
|
||||
{
|
||||
return _protocol.ParseMessage(message.Payload, this);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
|
@ -98,18 +117,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
Connection.Dispose();
|
||||
}
|
||||
|
||||
private class TestBinder : IInvocationBinder
|
||||
private static string GetInvocationId()
|
||||
{
|
||||
public Type[] GetParameterTypes(string methodName)
|
||||
{
|
||||
// TODO: Possibly support actual client methods
|
||||
return new[] { typeof(object) };
|
||||
}
|
||||
return Guid.NewGuid().ToString("N");
|
||||
}
|
||||
|
||||
public Type GetReturnType(string invocationId)
|
||||
{
|
||||
return typeof(object);
|
||||
}
|
||||
Type[] IInvocationBinder.GetParameterTypes(string methodName)
|
||||
{
|
||||
// TODO: Possibly support actual client methods
|
||||
return new[] { typeof(object) };
|
||||
}
|
||||
|
||||
Type IInvocationBinder.GetReturnType(string invocationId)
|
||||
{
|
||||
return typeof(object);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue