diff --git a/SignalR.sln b/SignalR.sln
index cb06185f7b..bdf0240d45 100644
--- a/SignalR.sln
+++ b/SignalR.sln
@@ -1,6 +1,6 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
-VisualStudioVersion = 15.0.26228.9
+VisualStudioVersion = 15.0.26411.1
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}"
EndProject
@@ -74,6 +74,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "client-ts", "client-ts", "{
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.SignalR.Microbenchmarks", "test\Microsoft.AspNetCore.SignalR.Microbenchmarks\Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj", "{96771B3F-4D18-41A7-A75B-FF38E76AAC89}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.Common.Tests", "test\Microsoft.AspNetCore.SignalR.Common.Tests\Microsoft.AspNetCore.SignalR.Common.Tests.csproj", "{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -180,6 +182,10 @@ Global
{96771B3F-4D18-41A7-A75B-FF38E76AAC89}.Debug|Any CPU.Build.0 = Debug|Any CPU
{96771B3F-4D18-41A7-A75B-FF38E76AAC89}.Release|Any CPU.ActiveCfg = Release|Any CPU
{96771B3F-4D18-41A7-A75B-FF38E76AAC89}.Release|Any CPU.Build.0 = Release|Any CPU
+ {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -211,5 +217,6 @@ Global
{333526A4-633B-491A-AC45-CC62A0012D1C} = {3A76C5A2-79ED-49BC-8BDC-6A3A766FFA1B}
{6CEC3DC2-5B01-45A8-8F0D-8531315DA90B} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
{96771B3F-4D18-41A7-A75B-FF38E76AAC89} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
+ {75E342F6-5445-4E7E-9143-6D9AE62C2B1E} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
EndGlobalSection
EndGlobal
diff --git a/build/dependencies.props b/build/dependencies.props
index 17b9f395ac..266b49aaff 100644
--- a/build/dependencies.props
+++ b/build/dependencies.props
@@ -1,6 +1,7 @@
-
+
0.4.0-*
+ 4.4.0-*
2.0.0-*
0.1.0-*
4.3.0
diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts
index c3eae1fb51..82b8d0a591 100644
--- a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts
+++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts
@@ -4,7 +4,7 @@ import { DataReceived, ConnectionClosed } from "../Microsoft.AspNetCore.SignalR.
import { TransportType, ITransport } from "../Microsoft.AspNetCore.SignalR.Client.TS/Transports"
describe("HubConnection", () => {
- it("completes pending invocations when stopped", async (done) => {
+ it("completes pending invocations when stopped", async done => {
let connection: IConnection = {
start(transportType: TransportType | ITransport): Promise {
return Promise.resolve();
@@ -27,18 +27,18 @@ describe("HubConnection", () => {
let hubConnection = new HubConnection(connection);
var invokePromise = hubConnection.invoke("testMethod");
hubConnection.stop();
- invokePromise
- .then(() => {
- fail();
- done();
- })
- .catch((error: Error) => {
- expect(error.message).toBe("Invocation cancelled due to connection being closed.");
- done();
- });
+
+ try {
+ await invokePromise;
+ fail();
+ }
+ catch (e) {
+ expect(e.message).toBe("Invocation cancelled due to connection being closed.");
+ }
+ done();
});
- it("completes pending invocations when connection is lost", async (done) => {
+ it("completes pending invocations when connection is lost", async done => {
let connection: IConnection = {
start(transportType: TransportType | ITransport): Promise {
return Promise.resolve();
@@ -60,17 +60,92 @@ describe("HubConnection", () => {
let hubConnection = new HubConnection(connection);
var invokePromise = hubConnection.invoke("testMethod");
- invokePromise
- .then(() => {
- fail();
- done();
- })
- .catch((error: Error) => {
- expect(error.message).toBe("Connection lost");
- done();
- });
-
// Typically this would be called by the transport
connection.onClosed(new Error("Connection lost"));
+
+ try {
+ await invokePromise;
+ fail();
+ }
+ catch (e) {
+ expect(e.message).toBe("Connection lost");
+ }
+ done();
+ });
+
+ it("sends invocations as nonblocking", async done => {
+ let dataSent: string;
+ let connection: IConnection = {
+ start(transportType: TransportType): Promise {
+ return Promise.resolve();
+ },
+
+ send(data: any): Promise {
+ dataSent = data;
+ return Promise.resolve();
+ },
+
+ stop(): void {
+ if (this.onClosed) {
+ this.onClosed();
+ }
+ },
+
+ onDataReceived: null,
+ onClosed: null
+ };
+
+ let hubConnection = new HubConnection(connection);
+ let invokePromise = hubConnection.invoke("testMethod");
+
+ expect(JSON.parse(dataSent).nonblocking).toBe(false);
+
+ // will clean pending promises
+ connection.onClosed();
+
+ try {
+ await invokePromise;
+ fail(); // exception is expected because the call has not completed
+ }
+ catch (e) {
+ }
+ done();
+ });
+
+ it("rejects streaming responses", async done => {
+ let connection: IConnection = {
+ start(transportType: TransportType): Promise {
+ return Promise.resolve();
+ },
+
+ send(data: any): Promise {
+ return Promise.resolve();
+ },
+
+ stop(): void {
+ if (this.onClosed) {
+ this.onClosed();
+ }
+ },
+
+ onDataReceived: null,
+ onClosed: null
+ };
+
+ let hubConnection = new HubConnection(connection);
+ let invokePromise = hubConnection.invoke("testMethod");
+
+ connection.onDataReceived("{ \"type\": 2, \"invocationId\": \"0\", \"result\": null }");
+ connection.onClosed();
+
+ try {
+ await invokePromise;
+ fail();
+ }
+ catch (e) {
+ expect(e.message).toBe("Streaming is not supported.");
+ }
+
+ done();
});
});
\ No newline at end of file
diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts
index 7f6c78a95b..4967e0078d 100644
--- a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts
+++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts
@@ -3,16 +3,31 @@ import { IConnection } from "./IConnection"
import { Connection } from "./Connection"
import { TransportType } from "./Transports"
-interface InvocationDescriptor {
- readonly Id: string;
- readonly Method: string;
- readonly Arguments: Array;
+
+const enum MessageType {
+ Invocation = 1,
+ Result,
+ Completion
}
-interface InvocationResultDescriptor {
- readonly Id: string;
- readonly Error: string;
- readonly Result: any;
+interface HubMessage {
+ readonly type: MessageType;
+ readonly invocationId: string;
+}
+
+interface InvocationMessage extends HubMessage {
+ readonly target: string;
+ readonly arguments: Array;
+ readonly nonblocking?: boolean;
+}
+
+interface ResultMessage extends HubMessage {
+ readonly result?: any;
+}
+
+interface CompletionMessage extends HubMessage {
+ readonly error?: string;
+ readonly result?: any;
}
export { Connection } from "./Connection"
@@ -20,7 +35,7 @@ export { TransportType } from "./Transports"
export class HubConnection {
private connection: IConnection;
- private callbacks: Map void>;
+ private callbacks: Map void>;
private methods: Map void>;
private id: number;
private connectionClosedCallback: ConnectionClosed;
@@ -40,7 +55,7 @@ export class HubConnection {
this.onConnectionClosed(error);
}
- this.callbacks = new Map void>();
+ this.callbacks = new Map void>();
this.methods = new Map void>();
this.id = 0;
}
@@ -51,34 +66,49 @@ export class HubConnection {
if (!data) {
return;
}
- var descriptor = JSON.parse(data);
- if (descriptor.Method === undefined) {
- let invocationResult: InvocationResultDescriptor = descriptor;
- let callback = this.callbacks.get(invocationResult.Id);
- if (callback != null) {
- callback(invocationResult);
- this.callbacks.delete(invocationResult.Id);
+
+ var message = JSON.parse(data);
+ switch (message.type) {
+ case MessageType.Invocation:
+ this.InvokeClientMethod(message);
+ break;
+ case MessageType.Result:
+ // TODO: Streaming (MessageType.Result) currently not supported - callback will throw
+ case MessageType.Completion:
+ let callback = this.callbacks.get(message.invocationId);
+ if (callback != null) {
+ callback(message);
+ this.callbacks.delete(message.invocationId);
+ }
+ break;
+ default:
+ console.log("Invalid message type: " + data);
+ break;
+ }
+ }
+
+ private InvokeClientMethod(invocationMessage: InvocationMessage) {
+ let method = this.methods.get(invocationMessage.target);
+ if (method) {
+ method.apply(this, invocationMessage.arguments);
+ if (!invocationMessage.nonblocking) {
+ // TODO: send result back to the server?
}
}
else {
- let invocation: InvocationDescriptor = descriptor;
- let method = this.methods[invocation.Method];
- if (method != null) {
- // TODO: bind? args?
- method.apply(this, invocation.Arguments);
- }
+ console.log(`No client method with the name '${invocationMessage.target}' found.`);
}
}
private onConnectionClosed(error: Error) {
- let errorInvocationResult = {
- Id: "-1",
- Error: error ? error.message : "Invocation cancelled due to connection being closed.",
- Result: null
- } as InvocationResultDescriptor;
+ let errorCompletionMessage = {
+ type: MessageType.Completion,
+ invocationId: "-1",
+ error: error ? error.message : "Invocation cancelled due to connection being closed.",
+ };
this.callbacks.forEach(callback => {
- callback(errorInvocationResult);
+ callback(errorCompletionMessage);
});
this.callbacks.clear();
@@ -99,19 +129,27 @@ export class HubConnection {
let id = this.id;
this.id++;
- let invocationDescriptor: InvocationDescriptor = {
- "Id": id.toString(),
- "Method": methodName,
- "Arguments": args
+ let invocationDescriptor: InvocationMessage = {
+ type: MessageType.Invocation,
+ invocationId: id.toString(),
+ target: methodName,
+ arguments: args,
+ nonblocking: false
};
let p = new Promise((resolve, reject) => {
- this.callbacks.set(invocationDescriptor.Id, (invocationResult: InvocationResultDescriptor) => {
- if (invocationResult.Error != null) {
- reject(new Error(invocationResult.Error));
+ this.callbacks.set(invocationDescriptor.invocationId, (invocationEvent: CompletionMessage | ResultMessage) => {
+ if (invocationEvent.type === MessageType.Completion) {
+ let completionMessage = invocationEvent;
+ if (completionMessage.error) {
+ reject(new Error(completionMessage.error));
+ }
+ else {
+ resolve(completionMessage.result);
+ }
}
else {
- resolve(invocationResult.Result);
+ reject(new Error("Streaming is not supported."))
}
});
@@ -119,7 +157,7 @@ export class HubConnection {
this.connection.send(JSON.stringify(invocationDescriptor))
.catch(e => {
reject(e);
- this.callbacks.delete(invocationDescriptor.Id);
+ this.callbacks.delete(invocationDescriptor.invocationId);
});
});
@@ -127,7 +165,7 @@ export class HubConnection {
}
on(methodName: string, method: (...args: any[]) => void) {
- this.methods[methodName] = method;
+ this.methods.set(methodName, method);
}
set onClosed(callback: ConnectionClosed) {
diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs
index 9a513eacda..771101f605 100644
--- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs
+++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs
@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.Builder;
@@ -24,7 +24,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server
app.UseDeveloperExceptionPage();
}
- app.UseStaticFiles();
+ app.UseFileServer();
app.UseSockets(options => options.MapEndpoint("/echo"));
app.UseSignalR(routes =>
{
diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/default.html b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/default.html
new file mode 100644
index 0000000000..d06a274aff
--- /dev/null
+++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/default.html
@@ -0,0 +1,13 @@
+
+
+
+
+ SignalR Tests
+
+
+ SignalR Tests
+
+
+
diff --git a/samples/ClientSample/HubSample.cs b/samples/ClientSample/HubSample.cs
index 049e20551d..a277899d00 100644
--- a/samples/ClientSample/HubSample.cs
+++ b/samples/ClientSample/HubSample.cs
@@ -5,7 +5,6 @@ using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
-using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.Extensions.CommandLineUtils;
using Microsoft.Extensions.Logging;
@@ -33,7 +32,7 @@ namespace ClientSample
var loggerFactory = new LoggerFactory();
Console.WriteLine("Connecting to {0}", baseUrl);
- var connection = new HubConnection(new Uri(baseUrl), new JsonNetInvocationAdapter(), loggerFactory);
+ var connection = new HubConnection(new Uri(baseUrl), loggerFactory);
try
{
await connection.StartAsync();
diff --git a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs
index f7fa6e25b3..899a3fc625 100644
--- a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs
+++ b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs
@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
@@ -22,6 +22,7 @@ namespace SocialWeather
public void OnConnectedAsync(Connection connection)
{
+ connection.Metadata[ConnectionMetadataNames.Format] = "json";
_connectionList.Add(connection);
}
@@ -34,11 +35,11 @@ namespace SocialWeather
{
foreach (var connection in _connectionList)
{
- var formatter = _formatterResolver.GetFormatter(connection.Metadata.Get("formatType"));
+ var formatter = _formatterResolver.GetFormatter(connection.Metadata.Get(ConnectionMetadataNames.Format));
var ms = new MemoryStream();
await formatter.WriteAsync(data, ms);
- var context = (HttpContext)connection.Metadata[typeof(HttpContext)];
+ var context = (HttpContext)connection.Metadata[ConnectionMetadataNames.HttpContext];
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? MessageType.Binary
diff --git a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs
index 5eec621674..3f68fa8f7d 100644
--- a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs
+++ b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs
@@ -17,7 +17,7 @@ namespace SocketsSample.EndPoints
{
Connections.Add(connection);
- await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})");
+ await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata[ConnectionMetadataNames.Transport]})");
try
{
@@ -37,7 +37,7 @@ namespace SocketsSample.EndPoints
{
Connections.Remove(connection);
- await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})");
+ await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata[ConnectionMetadataNames.Transport]})");
}
}
diff --git a/samples/SocketsSample/LineInvocationAdapter.cs b/samples/SocketsSample/LineInvocationAdapter.cs
deleted file mode 100644
index e2a17ed1a5..0000000000
--- a/samples/SocketsSample/LineInvocationAdapter.cs
+++ /dev/null
@@ -1,91 +0,0 @@
-// Copyright (c) .NET Foundation. All rights reserved.
-// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-
-using System;
-using System.IO;
-using System.Linq;
-using System.Threading;
-using System.Threading.Tasks;
-using Microsoft.AspNetCore.SignalR;
-
-namespace SocketsSample
-{
- public class LineInvocationAdapter : IInvocationAdapter
- {
- public async Task ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken)
- {
- var streamReader = new StreamReader(stream);
- var line = await streamReader.ReadLineAsync();
- if (line == null)
- {
- return null;
- }
-
- var values = line.Split(',');
-
- var type = values[0].Substring(0, 2);
- var id = values[0].Substring(2);
-
- if (type.Equals("RI"))
- {
- var resultType = values[1].Substring(0, 1);
- var result = values[1].Substring(1);
- return new InvocationResultDescriptor()
- {
- Id = id,
- Result = resultType.Equals("E") ? null : result,
- Error = resultType.Equals("E") ? result : null,
- };
- }
- else
- {
- var method = values[1].Substring(1);
-
- return new InvocationDescriptor
- {
- Id = id,
- Method = method,
- Arguments = values.Skip(2).Zip(binder.GetParameterTypes(method), (v, t) => Convert.ChangeType(v, t)).ToArray()
- };
- }
- }
-
- public Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken)
- {
- var invocationDescriptor = message as InvocationDescriptor;
- if (invocationDescriptor != null)
- {
- return WriteInvocationDescriptorAsync(invocationDescriptor, stream);
- }
- else
- {
- return WriteInvocationResultAsync((InvocationResultDescriptor)message, stream);
- }
- }
-
- private Task WriteInvocationDescriptorAsync(InvocationDescriptor invocationDescriptor, Stream stream)
- {
- var msg = $"CI{invocationDescriptor.Id},M{invocationDescriptor.Method},{string.Join(",", invocationDescriptor.Arguments.Select(a => a.ToString()))}\n";
- return WriteAsync(msg, stream);
- }
-
- private Task WriteInvocationResultAsync(InvocationResultDescriptor resultDescriptor, Stream stream)
- {
- if (string.IsNullOrEmpty(resultDescriptor.Error))
- {
- return WriteAsync($"RI{resultDescriptor.Id},E{resultDescriptor.Error}\n", stream);
- }
- else
- {
- return WriteAsync($"RI{resultDescriptor.Id},R{(resultDescriptor.Result != null ? resultDescriptor.Result.ToString() : string.Empty)}\n", stream);
- }
- }
-
- private async Task WriteAsync(string msg, Stream stream)
- {
- var writer = new StreamWriter(stream);
- await writer.WriteAsync(msg);
- await writer.FlushAsync();
- }
- }
-}
diff --git a/samples/SocketsSample/Protobuf/ProtobufInvocationAdapter.cs b/samples/SocketsSample/Protobuf/ProtobufInvocationAdapter.cs
deleted file mode 100644
index 9d5501e8a7..0000000000
--- a/samples/SocketsSample/Protobuf/ProtobufInvocationAdapter.cs
+++ /dev/null
@@ -1,168 +0,0 @@
-// Copyright (c) .NET Foundation. All rights reserved.
-// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-
-using System;
-using System.IO;
-using System.Threading;
-using System.Threading.Tasks;
-using Google.Protobuf;
-using Microsoft.AspNetCore.SignalR;
-using Microsoft.Extensions.DependencyInjection;
-
-namespace SocketsSample.Protobuf
-{
- public class ProtobufInvocationAdapter : IInvocationAdapter
- {
- private IServiceProvider _serviceProvider;
-
- public ProtobufInvocationAdapter(IServiceProvider serviceProvider)
- {
- _serviceProvider = serviceProvider;
- }
-
- public Task ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken)
- {
- return Task.Run(() => CreateInvocationMessageInt(stream, binder));
- }
-
- public Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken)
- {
- throw new NotImplementedException();
- }
-
- private Task CreateInvocationMessageInt(Stream stream, IInvocationBinder binder)
- {
- var inputStream = new CodedInputStream(stream, leaveOpen: true);
- var messageKind = new RpcMessageKind();
- inputStream.ReadMessage(messageKind);
- if(messageKind.MessageKind == RpcMessageKind.Types.Kind.Invocation)
- {
- return CreateInvocationDescriptorInt(inputStream, binder);
- }
- else
- {
- return CreateInvocationResultDescriptorInt(inputStream, binder);
- }
- }
-
- private Task CreateInvocationResultDescriptorInt(CodedInputStream inputStream, IInvocationBinder binder)
- {
- throw new NotImplementedException("Not yet implemented for Protobuf");
- }
-
- private Task CreateInvocationDescriptorInt(CodedInputStream inputStream, IInvocationBinder binder)
- {
- var invocationHeader = new RpcInvocationHeader();
- inputStream.ReadMessage(invocationHeader);
- var argumentTypes = binder.GetParameterTypes(invocationHeader.Name);
-
- var invocationDescriptor = new InvocationDescriptor();
- invocationDescriptor.Method = invocationHeader.Name;
- invocationDescriptor.Id = invocationHeader.Id.ToString();
- invocationDescriptor.Arguments = new object[argumentTypes.Length];
-
- var primitiveParser = PrimitiveValue.Parser;
-
- for (var i = 0; i < argumentTypes.Length; i++)
- {
- if (typeof(int) == argumentTypes[i])
- {
- var value = new PrimitiveValue();
- inputStream.ReadMessage(value);
- invocationDescriptor.Arguments[i] = value.Int32Value;
- }
- else if (typeof(string) == argumentTypes[i])
- {
- var value = new PrimitiveValue();
- inputStream.ReadMessage(value);
- invocationDescriptor.Arguments[i] = value.StringValue;
- }
- else
- {
- var serializer = _serviceProvider.GetRequiredService();
- invocationDescriptor.Arguments[i] = serializer.GetValue(inputStream, argumentTypes[i]);
- }
- }
-
- return Task.FromResult(invocationDescriptor);
- }
-
- public async Task WriteInvocationResultAsync(InvocationResultDescriptor resultDescriptor, Stream stream)
- {
- var outputStream = new CodedOutputStream(stream, leaveOpen: true);
- outputStream.WriteMessage(new RpcMessageKind() { MessageKind = RpcMessageKind.Types.Kind.Result });
-
- var resultHeader = new RpcInvocationResultHeader
- {
- Id = int.Parse(resultDescriptor.Id),
- HasResult = resultDescriptor.Result != null
- };
-
- if (resultDescriptor.Error != null)
- {
- resultHeader.Error = resultDescriptor.Error;
- }
-
- outputStream.WriteMessage(resultHeader);
-
- if (string.IsNullOrEmpty(resultHeader.Error) && resultDescriptor.Result != null)
- {
- var result = resultDescriptor.Result;
-
- if (result.GetType() == typeof(int))
- {
- outputStream.WriteMessage(new PrimitiveValue { Int32Value = (int)result });
- }
- else if (result.GetType() == typeof(string))
- {
- outputStream.WriteMessage(new PrimitiveValue { StringValue = (string)result });
- }
- else
- {
- var serializer = _serviceProvider.GetRequiredService();
- var message = serializer.GetMessage(result);
- outputStream.WriteMessage(message);
- }
- }
-
- outputStream.Flush();
- await stream.FlushAsync();
- }
-
- public async Task WriteInvocationDescriptorAsync(InvocationDescriptor invocationDescriptor, Stream stream)
- {
- var outputStream = new CodedOutputStream(stream, leaveOpen: true);
- outputStream.WriteMessage(new RpcMessageKind() { MessageKind = RpcMessageKind.Types.Kind.Invocation });
-
- var invocationHeader = new RpcInvocationHeader()
- {
- Id = 0,
- Name = invocationDescriptor.Method,
- NumArgs = invocationDescriptor.Arguments.Length
- };
-
- outputStream.WriteMessage(invocationHeader);
-
- foreach (var arg in invocationDescriptor.Arguments)
- {
- if (arg.GetType() == typeof(int))
- {
- outputStream.WriteMessage(new PrimitiveValue { Int32Value = (int)arg });
- }
- else if (arg.GetType() == typeof(string))
- {
- outputStream.WriteMessage(new PrimitiveValue { StringValue = (string)arg });
- }
- else
- {
- var serializer = _serviceProvider.GetRequiredService();
- var message = serializer.GetMessage(arg);
- outputStream.WriteMessage(message);
- }
- }
-
- outputStream.Flush();
- await stream.FlushAsync();
- }
- }
-}
diff --git a/samples/SocketsSample/Protobuf/RpcInvocation.cs b/samples/SocketsSample/Protobuf/RpcInvocation.cs
deleted file mode 100644
index a20fc10e4f..0000000000
--- a/samples/SocketsSample/Protobuf/RpcInvocation.cs
+++ /dev/null
@@ -1,848 +0,0 @@
-// Copyright (c) .NET Foundation. All rights reserved.
-// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-
-// Generated by the protocol buffer compiler. DO NOT EDIT!
-// source: RpcInvocation.proto
-#pragma warning disable 1591, 0612, 3021
-#region Designer generated code
-
-using pb = global::Google.Protobuf;
-using pbc = global::Google.Protobuf.Collections;
-using pbr = global::Google.Protobuf.Reflection;
-using scg = global::System.Collections.Generic;
-/// Holder for reflection information generated from RpcInvocation.proto
-public static partial class RpcInvocationReflection {
-
- #region Descriptor
- /// File descriptor for RpcInvocation.proto
- public static pbr::FileDescriptor Descriptor {
- get { return descriptor; }
- }
- private static pbr::FileDescriptor descriptor;
-
- static RpcInvocationReflection() {
- byte[] descriptorData = global::System.Convert.FromBase64String(
- string.Concat(
- "ChNScGNJbnZvY2F0aW9uLnByb3RvIl8KDlJwY01lc3NhZ2VLaW5kEikKC01l",
- "c3NhZ2VLaW5kGAEgASgOMhQuUnBjTWVzc2FnZUtpbmQuS2luZCIiCgRLaW5k",
- "EgoKBlJlc3VsdBAAEg4KCkludm9jYXRpb24QASJAChNScGNJbnZvY2F0aW9u",
- "SGVhZGVyEgwKBE5hbWUYASABKAkSCgoCSWQYAiABKAUSDwoHTnVtQXJncxgD",
- "IAEoBSJJChlScGNJbnZvY2F0aW9uUmVzdWx0SGVhZGVyEgoKAklkGAEgASgF",
- "EhEKCUhhc1Jlc3VsdBgCIAEoCBINCgVFcnJvchgDIAEoCSJHCg5QcmltaXRp",
- "dmVWYWx1ZRIUCgpJbnQzMlZhbHVlGAEgASgFSAASFQoLU3RyaW5nVmFsdWUY",
- "AiABKAlIAEIICgZvbmVvZl8iKgoNUGVyc29uTWVzc2FnZRIMCgROYW1lGAEg",
- "ASgJEgsKA0FnZRgCIAEoA2IGcHJvdG8z"));
- descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
- new pbr::FileDescriptor[] { },
- new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::RpcMessageKind), global::RpcMessageKind.Parser, new[]{ "MessageKind" }, null, new[]{ typeof(global::RpcMessageKind.Types.Kind) }, null),
- new pbr::GeneratedClrTypeInfo(typeof(global::RpcInvocationHeader), global::RpcInvocationHeader.Parser, new[]{ "Name", "Id", "NumArgs" }, null, null, null),
- new pbr::GeneratedClrTypeInfo(typeof(global::RpcInvocationResultHeader), global::RpcInvocationResultHeader.Parser, new[]{ "Id", "HasResult", "Error" }, null, null, null),
- new pbr::GeneratedClrTypeInfo(typeof(global::PrimitiveValue), global::PrimitiveValue.Parser, new[]{ "Int32Value", "StringValue" }, new[]{ "Oneof" }, null, null),
- new pbr::GeneratedClrTypeInfo(typeof(global::PersonMessage), global::PersonMessage.Parser, new[]{ "Name", "Age" }, null, null, null)
- }));
- }
- #endregion
-
-}
-#region Messages
-public sealed partial class RpcMessageKind : pb::IMessage {
- private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RpcMessageKind());
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pb::MessageParser Parser { get { return _parser; } }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pbr::MessageDescriptor Descriptor {
- get { return global::RpcInvocationReflection.Descriptor.MessageTypes[0]; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- pbr::MessageDescriptor pb::IMessage.Descriptor {
- get { return Descriptor; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public RpcMessageKind() {
- OnConstruction();
- }
-
- partial void OnConstruction();
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public RpcMessageKind(RpcMessageKind other) : this() {
- messageKind_ = other.messageKind_;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public RpcMessageKind Clone() {
- return new RpcMessageKind(this);
- }
-
- /// Field number for the "MessageKind" field.
- public const int MessageKindFieldNumber = 1;
- private global::RpcMessageKind.Types.Kind messageKind_ = 0;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public global::RpcMessageKind.Types.Kind MessageKind {
- get { return messageKind_; }
- set {
- messageKind_ = value;
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override bool Equals(object other) {
- return Equals(other as RpcMessageKind);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public bool Equals(RpcMessageKind other) {
- if (ReferenceEquals(other, null)) {
- return false;
- }
- if (ReferenceEquals(other, this)) {
- return true;
- }
- if (MessageKind != other.MessageKind) return false;
- return true;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override int GetHashCode() {
- int hash = 1;
- if (MessageKind != 0) hash ^= MessageKind.GetHashCode();
- return hash;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override string ToString() {
- return pb::JsonFormatter.ToDiagnosticString(this);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void WriteTo(pb::CodedOutputStream output) {
- if (MessageKind != 0) {
- output.WriteRawTag(8);
- output.WriteEnum((int) MessageKind);
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int CalculateSize() {
- int size = 0;
- if (MessageKind != 0) {
- size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) MessageKind);
- }
- return size;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(RpcMessageKind other) {
- if (other == null) {
- return;
- }
- if (other.MessageKind != 0) {
- MessageKind = other.MessageKind;
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(pb::CodedInputStream input) {
- uint tag;
- while ((tag = input.ReadTag()) != 0) {
- switch(tag) {
- default:
- input.SkipLastField();
- break;
- case 8: {
- messageKind_ = (global::RpcMessageKind.Types.Kind) input.ReadEnum();
- break;
- }
- }
- }
- }
-
- #region Nested types
- /// Container for nested types declared in the RpcMessageKind message type.
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static partial class Types {
- public enum Kind {
- [pbr::OriginalName("Result")] Result = 0,
- [pbr::OriginalName("Invocation")] Invocation = 1,
- }
-
- }
- #endregion
-
-}
-
-public sealed partial class RpcInvocationHeader : pb::IMessage {
- private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RpcInvocationHeader());
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pb::MessageParser Parser { get { return _parser; } }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pbr::MessageDescriptor Descriptor {
- get { return global::RpcInvocationReflection.Descriptor.MessageTypes[1]; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- pbr::MessageDescriptor pb::IMessage.Descriptor {
- get { return Descriptor; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public RpcInvocationHeader() {
- OnConstruction();
- }
-
- partial void OnConstruction();
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public RpcInvocationHeader(RpcInvocationHeader other) : this() {
- name_ = other.name_;
- id_ = other.id_;
- numArgs_ = other.numArgs_;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public RpcInvocationHeader Clone() {
- return new RpcInvocationHeader(this);
- }
-
- /// Field number for the "Name" field.
- public const int NameFieldNumber = 1;
- private string name_ = "";
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public string Name {
- get { return name_; }
- set {
- name_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
- }
- }
-
- /// Field number for the "Id" field.
- public const int IdFieldNumber = 2;
- private int id_;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int Id {
- get { return id_; }
- set {
- id_ = value;
- }
- }
-
- /// Field number for the "NumArgs" field.
- public const int NumArgsFieldNumber = 3;
- private int numArgs_;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int NumArgs {
- get { return numArgs_; }
- set {
- numArgs_ = value;
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override bool Equals(object other) {
- return Equals(other as RpcInvocationHeader);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public bool Equals(RpcInvocationHeader other) {
- if (ReferenceEquals(other, null)) {
- return false;
- }
- if (ReferenceEquals(other, this)) {
- return true;
- }
- if (Name != other.Name) return false;
- if (Id != other.Id) return false;
- if (NumArgs != other.NumArgs) return false;
- return true;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override int GetHashCode() {
- int hash = 1;
- if (Name.Length != 0) hash ^= Name.GetHashCode();
- if (Id != 0) hash ^= Id.GetHashCode();
- if (NumArgs != 0) hash ^= NumArgs.GetHashCode();
- return hash;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override string ToString() {
- return pb::JsonFormatter.ToDiagnosticString(this);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void WriteTo(pb::CodedOutputStream output) {
- if (Name.Length != 0) {
- output.WriteRawTag(10);
- output.WriteString(Name);
- }
- if (Id != 0) {
- output.WriteRawTag(16);
- output.WriteInt32(Id);
- }
- if (NumArgs != 0) {
- output.WriteRawTag(24);
- output.WriteInt32(NumArgs);
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int CalculateSize() {
- int size = 0;
- if (Name.Length != 0) {
- size += 1 + pb::CodedOutputStream.ComputeStringSize(Name);
- }
- if (Id != 0) {
- size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id);
- }
- if (NumArgs != 0) {
- size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumArgs);
- }
- return size;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(RpcInvocationHeader other) {
- if (other == null) {
- return;
- }
- if (other.Name.Length != 0) {
- Name = other.Name;
- }
- if (other.Id != 0) {
- Id = other.Id;
- }
- if (other.NumArgs != 0) {
- NumArgs = other.NumArgs;
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(pb::CodedInputStream input) {
- uint tag;
- while ((tag = input.ReadTag()) != 0) {
- switch(tag) {
- default:
- input.SkipLastField();
- break;
- case 10: {
- Name = input.ReadString();
- break;
- }
- case 16: {
- Id = input.ReadInt32();
- break;
- }
- case 24: {
- NumArgs = input.ReadInt32();
- break;
- }
- }
- }
- }
-
-}
-
-public sealed partial class RpcInvocationResultHeader : pb::IMessage {
- private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RpcInvocationResultHeader());
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pb::MessageParser Parser { get { return _parser; } }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pbr::MessageDescriptor Descriptor {
- get { return global::RpcInvocationReflection.Descriptor.MessageTypes[2]; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- pbr::MessageDescriptor pb::IMessage.Descriptor {
- get { return Descriptor; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public RpcInvocationResultHeader() {
- OnConstruction();
- }
-
- partial void OnConstruction();
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public RpcInvocationResultHeader(RpcInvocationResultHeader other) : this() {
- id_ = other.id_;
- hasResult_ = other.hasResult_;
- error_ = other.error_;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public RpcInvocationResultHeader Clone() {
- return new RpcInvocationResultHeader(this);
- }
-
- /// Field number for the "Id" field.
- public const int IdFieldNumber = 1;
- private int id_;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int Id {
- get { return id_; }
- set {
- id_ = value;
- }
- }
-
- /// Field number for the "HasResult" field.
- public const int HasResultFieldNumber = 2;
- private bool hasResult_;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public bool HasResult {
- get { return hasResult_; }
- set {
- hasResult_ = value;
- }
- }
-
- /// Field number for the "Error" field.
- public const int ErrorFieldNumber = 3;
- private string error_ = "";
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public string Error {
- get { return error_; }
- set {
- error_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override bool Equals(object other) {
- return Equals(other as RpcInvocationResultHeader);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public bool Equals(RpcInvocationResultHeader other) {
- if (ReferenceEquals(other, null)) {
- return false;
- }
- if (ReferenceEquals(other, this)) {
- return true;
- }
- if (Id != other.Id) return false;
- if (HasResult != other.HasResult) return false;
- if (Error != other.Error) return false;
- return true;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override int GetHashCode() {
- int hash = 1;
- if (Id != 0) hash ^= Id.GetHashCode();
- if (HasResult != false) hash ^= HasResult.GetHashCode();
- if (Error.Length != 0) hash ^= Error.GetHashCode();
- return hash;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override string ToString() {
- return pb::JsonFormatter.ToDiagnosticString(this);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void WriteTo(pb::CodedOutputStream output) {
- if (Id != 0) {
- output.WriteRawTag(8);
- output.WriteInt32(Id);
- }
- if (HasResult != false) {
- output.WriteRawTag(16);
- output.WriteBool(HasResult);
- }
- if (Error.Length != 0) {
- output.WriteRawTag(26);
- output.WriteString(Error);
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int CalculateSize() {
- int size = 0;
- if (Id != 0) {
- size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id);
- }
- if (HasResult != false) {
- size += 1 + 1;
- }
- if (Error.Length != 0) {
- size += 1 + pb::CodedOutputStream.ComputeStringSize(Error);
- }
- return size;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(RpcInvocationResultHeader other) {
- if (other == null) {
- return;
- }
- if (other.Id != 0) {
- Id = other.Id;
- }
- if (other.HasResult != false) {
- HasResult = other.HasResult;
- }
- if (other.Error.Length != 0) {
- Error = other.Error;
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(pb::CodedInputStream input) {
- uint tag;
- while ((tag = input.ReadTag()) != 0) {
- switch(tag) {
- default:
- input.SkipLastField();
- break;
- case 8: {
- Id = input.ReadInt32();
- break;
- }
- case 16: {
- HasResult = input.ReadBool();
- break;
- }
- case 26: {
- Error = input.ReadString();
- break;
- }
- }
- }
- }
-
-}
-
-public sealed partial class PrimitiveValue : pb::IMessage {
- private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new PrimitiveValue());
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pb::MessageParser Parser { get { return _parser; } }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pbr::MessageDescriptor Descriptor {
- get { return global::RpcInvocationReflection.Descriptor.MessageTypes[3]; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- pbr::MessageDescriptor pb::IMessage.Descriptor {
- get { return Descriptor; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public PrimitiveValue() {
- OnConstruction();
- }
-
- partial void OnConstruction();
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public PrimitiveValue(PrimitiveValue other) : this() {
- switch (other.OneofCase) {
- case OneofOneofCase.Int32Value:
- Int32Value = other.Int32Value;
- break;
- case OneofOneofCase.StringValue:
- StringValue = other.StringValue;
- break;
- }
-
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public PrimitiveValue Clone() {
- return new PrimitiveValue(this);
- }
-
- /// Field number for the "Int32Value" field.
- public const int Int32ValueFieldNumber = 1;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int Int32Value {
- get { return oneofCase_ == OneofOneofCase.Int32Value ? (int) oneof_ : 0; }
- set {
- oneof_ = value;
- oneofCase_ = OneofOneofCase.Int32Value;
- }
- }
-
- /// Field number for the "StringValue" field.
- public const int StringValueFieldNumber = 2;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public string StringValue {
- get { return oneofCase_ == OneofOneofCase.StringValue ? (string) oneof_ : ""; }
- set {
- oneof_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
- oneofCase_ = OneofOneofCase.StringValue;
- }
- }
-
- private object oneof_;
- /// Enum of possible cases for the "oneof_" oneof.
- public enum OneofOneofCase {
- None = 0,
- Int32Value = 1,
- StringValue = 2,
- }
- private OneofOneofCase oneofCase_ = OneofOneofCase.None;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public OneofOneofCase OneofCase {
- get { return oneofCase_; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void ClearOneof() {
- oneofCase_ = OneofOneofCase.None;
- oneof_ = null;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override bool Equals(object other) {
- return Equals(other as PrimitiveValue);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public bool Equals(PrimitiveValue other) {
- if (ReferenceEquals(other, null)) {
- return false;
- }
- if (ReferenceEquals(other, this)) {
- return true;
- }
- if (Int32Value != other.Int32Value) return false;
- if (StringValue != other.StringValue) return false;
- if (OneofCase != other.OneofCase) return false;
- return true;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override int GetHashCode() {
- int hash = 1;
- if (oneofCase_ == OneofOneofCase.Int32Value) hash ^= Int32Value.GetHashCode();
- if (oneofCase_ == OneofOneofCase.StringValue) hash ^= StringValue.GetHashCode();
- hash ^= (int) oneofCase_;
- return hash;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override string ToString() {
- return pb::JsonFormatter.ToDiagnosticString(this);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void WriteTo(pb::CodedOutputStream output) {
- if (oneofCase_ == OneofOneofCase.Int32Value) {
- output.WriteRawTag(8);
- output.WriteInt32(Int32Value);
- }
- if (oneofCase_ == OneofOneofCase.StringValue) {
- output.WriteRawTag(18);
- output.WriteString(StringValue);
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int CalculateSize() {
- int size = 0;
- if (oneofCase_ == OneofOneofCase.Int32Value) {
- size += 1 + pb::CodedOutputStream.ComputeInt32Size(Int32Value);
- }
- if (oneofCase_ == OneofOneofCase.StringValue) {
- size += 1 + pb::CodedOutputStream.ComputeStringSize(StringValue);
- }
- return size;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(PrimitiveValue other) {
- if (other == null) {
- return;
- }
- switch (other.OneofCase) {
- case OneofOneofCase.Int32Value:
- Int32Value = other.Int32Value;
- break;
- case OneofOneofCase.StringValue:
- StringValue = other.StringValue;
- break;
- }
-
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(pb::CodedInputStream input) {
- uint tag;
- while ((tag = input.ReadTag()) != 0) {
- switch(tag) {
- default:
- input.SkipLastField();
- break;
- case 8: {
- Int32Value = input.ReadInt32();
- break;
- }
- case 18: {
- StringValue = input.ReadString();
- break;
- }
- }
- }
- }
-
-}
-
-public sealed partial class PersonMessage : pb::IMessage {
- private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new PersonMessage());
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pb::MessageParser Parser { get { return _parser; } }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pbr::MessageDescriptor Descriptor {
- get { return global::RpcInvocationReflection.Descriptor.MessageTypes[4]; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- pbr::MessageDescriptor pb::IMessage.Descriptor {
- get { return Descriptor; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public PersonMessage() {
- OnConstruction();
- }
-
- partial void OnConstruction();
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public PersonMessage(PersonMessage other) : this() {
- name_ = other.name_;
- age_ = other.age_;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public PersonMessage Clone() {
- return new PersonMessage(this);
- }
-
- /// Field number for the "Name" field.
- public const int NameFieldNumber = 1;
- private string name_ = "";
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public string Name {
- get { return name_; }
- set {
- name_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
- }
- }
-
- /// Field number for the "Age" field.
- public const int AgeFieldNumber = 2;
- private long age_;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public long Age {
- get { return age_; }
- set {
- age_ = value;
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override bool Equals(object other) {
- return Equals(other as PersonMessage);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public bool Equals(PersonMessage other) {
- if (ReferenceEquals(other, null)) {
- return false;
- }
- if (ReferenceEquals(other, this)) {
- return true;
- }
- if (Name != other.Name) return false;
- if (Age != other.Age) return false;
- return true;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override int GetHashCode() {
- int hash = 1;
- if (Name.Length != 0) hash ^= Name.GetHashCode();
- if (Age != 0L) hash ^= Age.GetHashCode();
- return hash;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override string ToString() {
- return pb::JsonFormatter.ToDiagnosticString(this);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void WriteTo(pb::CodedOutputStream output) {
- if (Name.Length != 0) {
- output.WriteRawTag(10);
- output.WriteString(Name);
- }
- if (Age != 0L) {
- output.WriteRawTag(16);
- output.WriteInt64(Age);
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int CalculateSize() {
- int size = 0;
- if (Name.Length != 0) {
- size += 1 + pb::CodedOutputStream.ComputeStringSize(Name);
- }
- if (Age != 0L) {
- size += 1 + pb::CodedOutputStream.ComputeInt64Size(Age);
- }
- return size;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(PersonMessage other) {
- if (other == null) {
- return;
- }
- if (other.Name.Length != 0) {
- Name = other.Name;
- }
- if (other.Age != 0L) {
- Age = other.Age;
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(pb::CodedInputStream input) {
- uint tag;
- while ((tag = input.ReadTag()) != 0) {
- switch(tag) {
- default:
- input.SkipLastField();
- break;
- case 10: {
- Name = input.ReadString();
- break;
- }
- case 16: {
- Age = input.ReadInt64();
- break;
- }
- }
- }
- }
-
-}
-
-#endregion
-
-
-#endregion Designer generated code
diff --git a/samples/SocketsSample/Protobuf/RpcInvocation.proto b/samples/SocketsSample/Protobuf/RpcInvocation.proto
deleted file mode 100644
index f63ad75a05..0000000000
--- a/samples/SocketsSample/Protobuf/RpcInvocation.proto
+++ /dev/null
@@ -1,30 +0,0 @@
-syntax = "proto3";
-
-message RpcMessageKind {
- enum Kind { Result = 0; Invocation = 1; }
- Kind MessageKind = 1;
-}
-
-message RpcInvocationHeader {
- string Name = 1;
- int32 Id = 2;
- int32 NumArgs = 3;
-}
-
-message RpcInvocationResultHeader {
- int32 Id = 1;
- bool HasResult = 2;
- string Error = 3;
-}
-
-message PrimitiveValue {
- oneof oneof_ {
- int32 Int32Value = 1;
- string StringValue = 2;
- }
-}
-
-message PersonMessage {
- string Name = 1;
- int64 Age = 2;
-}
\ No newline at end of file
diff --git a/samples/SocketsSample/ProtobufSerializer.cs b/samples/SocketsSample/ProtobufSerializer.cs
deleted file mode 100644
index 3b1aaec672..0000000000
--- a/samples/SocketsSample/ProtobufSerializer.cs
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright (c) .NET Foundation. All rights reserved.
-// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-
-using System;
-using Google.Protobuf;
-using SocketsSample.Hubs;
-
-namespace SocketsSample
-{
- public class ProtobufSerializer
- {
- public object GetValue(CodedInputStream inputStream, Type type)
- {
- if (type == typeof(Person))
- {
- var value = new PersonMessage();
- inputStream.ReadMessage(value);
-
- return new Person { Name = value.Name, Age = value.Age };
- }
-
- throw new InvalidOperationException("(Deserialize) Unknown type.");
- }
-
- public IMessage GetMessage(object value)
- {
- Person person = value as Person;
- if (person != null)
- {
- return new PersonMessage { Name = person.Name, Age = person.Age };
- }
-
- throw new InvalidOperationException("(Serialize) Unknown type.");
- }
- }
-}
diff --git a/samples/SocketsSample/Startup.cs b/samples/SocketsSample/Startup.cs
index b8f5f3feb4..93b0232efc 100644
--- a/samples/SocketsSample/Startup.cs
+++ b/samples/SocketsSample/Startup.cs
@@ -3,12 +3,10 @@
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
-using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using SocketsSample.EndPoints;
using SocketsSample.Hubs;
-using SocketsSample.Protobuf;
namespace SocketsSample
{
@@ -18,21 +16,12 @@ namespace SocketsSample
// For more information on how to configure your application, visit http://go.microsoft.com/fwlink/?LinkID=398940
public void ConfigureServices(IServiceCollection services)
{
- services.AddSingleton();
- services.AddSingleton();
-
services.AddSockets();
- services.AddSignalR(options =>
- {
- options.RegisterInvocationAdapter("protobuf");
- options.RegisterInvocationAdapter("line");
- });
+ services.AddSignalR();
// .AddRedis();
services.AddEndPoint();
-
- services.AddSingleton();
}
// This method gets called by the runtime. Use this method to configure the HTTP request pipeline.
diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/IOutputExtensions.cs b/src/Common/IOutputExtensions.cs
similarity index 100%
rename from src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/IOutputExtensions.cs
rename to src/Common/IOutputExtensions.cs
diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs
index c1b946e9e1..fa7bad1615 100644
--- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs
+++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs
@@ -1,19 +1,20 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
-using System.Diagnostics;
-using System.IO;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.AspNetCore.SignalR.Internal;
+using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
+using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Client
{
@@ -22,17 +23,14 @@ namespace Microsoft.AspNetCore.SignalR.Client
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
private readonly IConnection _connection;
- private readonly IInvocationAdapter _adapter;
+ private readonly IHubProtocol _protocol;
private readonly HubBinder _binder;
private HttpClient _httpClient;
- private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();
-
- // We need to ensure pending calls added after a connection failure don't hang. Right now the easiest thing to do is lock.
private readonly object _pendingCallsLock = new object();
+ private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();
private readonly Dictionary _pendingCalls = new Dictionary();
-
private readonly ConcurrentDictionary _handlers = new ConcurrentDictionary();
private int _nextId = 0;
@@ -50,27 +48,33 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
public HubConnection(Uri url)
- : this(new Connection(url), new JsonNetInvocationAdapter(), null)
+ : this(new Connection(url), new JsonHubProtocol(new JsonSerializer()), null)
{ }
public HubConnection(Uri url, ILoggerFactory loggerFactory)
- : this(new Connection(url), new JsonNetInvocationAdapter(), loggerFactory)
+ : this(new Connection(url), new JsonHubProtocol(new JsonSerializer()), loggerFactory)
{ }
- public HubConnection(Uri url, IInvocationAdapter adapter, ILoggerFactory loggerFactory)
- : this(new Connection(url, loggerFactory), adapter, loggerFactory)
+ // These are only really needed for tests now...
+ public HubConnection(IConnection connection, ILoggerFactory loggerFactory)
+ : this(connection, new JsonHubProtocol(new JsonSerializer()), loggerFactory)
{ }
- public HubConnection(IConnection connection, IInvocationAdapter adapter, ILoggerFactory loggerFactory)
+ public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFactory loggerFactory)
{
if (connection == null)
{
throw new ArgumentNullException(nameof(connection));
}
+ if (protocol == null)
+ {
+ throw new ArgumentNullException(nameof(protocol));
+ }
+
_connection = connection;
_binder = new HubBinder(this);
- _adapter = adapter;
+ _protocol = protocol;
_loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
_logger = _loggerFactory.CreateLogger();
_connection.Received += OnDataReceived;
@@ -100,6 +104,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
public async Task DisposeAsync()
{
await _connection.DisposeAsync();
+
_httpClient?.Dispose();
}
@@ -118,81 +123,80 @@ namespace Microsoft.AspNetCore.SignalR.Client
public Task
-
+
+
+
+
+
+
+
+
+
diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj b/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj
index e60402fb46..551cab5dd5 100644
--- a/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj
+++ b/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj
@@ -13,7 +13,6 @@
-
diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs
index d744de4130..cb847095a6 100644
--- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs
+++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs
@@ -5,53 +5,80 @@ using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
-using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
-using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
+using Newtonsoft.Json;
using StackExchange.Redis;
namespace Microsoft.AspNetCore.SignalR.Redis
{
public class RedisHubLifetimeManager : HubLifetimeManager, IDisposable
{
+ private const string RedisSubscriptionsMetadataName = "redis_subscriptions";
+
private readonly ConnectionList _connections = new ConnectionList();
// TODO: Investigate "memory leak" entries never get removed
private readonly ConcurrentDictionary _groups = new ConcurrentDictionary();
- private readonly InvocationAdapterRegistry _registry;
private readonly ConnectionMultiplexer _redisServerConnection;
private readonly ISubscriber _bus;
- private readonly ILoggerFactory _loggerFactory;
+ private readonly ILogger _logger;
private readonly RedisOptions _options;
- public RedisHubLifetimeManager(InvocationAdapterRegistry registry,
- ILoggerFactory loggerFactory,
+ // This serializer is ONLY use to transmit the data through redis, it has no connection to the serializer used on each connection.
+ private readonly JsonSerializer _serializer = new JsonSerializer
+ {
+ // We need to serialize objects "full-fidelity", even if it is noisy, so we preserve the original types
+ TypeNameAssemblyFormatHandling = TypeNameAssemblyFormatHandling.Simple,
+ TypeNameHandling = TypeNameHandling.All,
+ Formatting = Formatting.None
+ };
+
+ private long _nextInvocationId = 0;
+
+ public RedisHubLifetimeManager(ILogger> logger,
IOptions options)
{
- _loggerFactory = loggerFactory;
- _registry = registry;
+ _logger = logger;
_options = options.Value;
- var writer = new LoggerTextWriter(loggerFactory.CreateLogger>());
+ var writer = new LoggerTextWriter(logger);
+ _logger.LogInformation("Connecting to redis endpoints: {endpoints}", string.Join(", ", options.Value.Options.EndPoints.Select(e => EndPointCollection.ToString(e))));
_redisServerConnection = _options.Connect(writer);
+ if (_redisServerConnection.IsConnected)
+ {
+ _logger.LogInformation("Connected to redis");
+ }
+ else
+ {
+ // TODO: We could support reconnecting, like old SignalR does.
+ throw new InvalidOperationException("Connection to redis failed.");
+ }
_bus = _redisServerConnection.GetSubscriber();
- var previousBroadcastTask = TaskCache.CompletedTask;
+ var previousBroadcastTask = Task.CompletedTask;
- _bus.Subscribe(typeof(THub).FullName, async (c, data) =>
+ var channelName = typeof(THub).FullName;
+ _logger.LogInformation("Subscribing to channel: {channel}", channelName);
+ _bus.Subscribe(channelName, async (c, data) =>
{
await previousBroadcastTask;
+ _logger.LogTrace("Received message from redis channel {channel}", channelName);
+
+ var message = DeserializeMessage(data);
+
+ // TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
var tasks = new List(_connections.Count);
foreach (var connection in _connections)
{
- tasks.Add(WriteAsync(connection, data));
+ tasks.Add(WriteAsync(connection, message));
}
previousBroadcastTask = Task.WhenAll(tasks);
@@ -60,80 +87,68 @@ namespace Microsoft.AspNetCore.SignalR.Redis
public override Task InvokeAllAsync(string methodName, object[] args)
{
- var message = new InvocationDescriptor
- {
- Method = methodName,
- Arguments = args
- };
+ var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
return PublishAsync(typeof(THub).FullName, message);
}
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
{
- var message = new InvocationDescriptor
- {
- Method = methodName,
- Arguments = args
- };
+ var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
return PublishAsync(typeof(THub).FullName + "." + connectionId, message);
}
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
{
- var message = new InvocationDescriptor
- {
- Method = methodName,
- Arguments = args
- };
+ var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
return PublishAsync(typeof(THub).FullName + ".group." + groupName, message);
}
public override Task InvokeUserAsync(string userId, string methodName, object[] args)
{
- var message = new InvocationDescriptor
- {
- Method = methodName,
- Arguments = args
- };
+ var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
return PublishAsync(typeof(THub).FullName + ".user." + userId, message);
}
- private async Task PublishAsync(string channel, InvocationDescriptor message)
+ private async Task PublishAsync(string channel, HubMessage hubMessage)
{
- // TODO: What format??
- var invocationAdapter = _registry.GetInvocationAdapter("json");
-
- // BAD
- using (var ms = new MemoryStream())
+ byte[] payload;
+ using (var stream = new MemoryStream())
+ using (var writer = new JsonTextWriter(new StreamWriter(stream)))
{
- await invocationAdapter.WriteMessageAsync(message, ms);
-
- await _bus.PublishAsync(channel, ms.ToArray());
+ _serializer.Serialize(writer, hubMessage);
+ await writer.FlushAsync();
+ payload = stream.ToArray();
}
+
+ _logger.LogTrace("Publishing message to redis channel {channel}", channel);
+ await _bus.PublishAsync(channel, payload);
}
public override Task OnConnectedAsync(Connection connection)
{
- var redisSubscriptions = connection.Metadata.GetOrAdd("redis_subscriptions", _ => new HashSet());
- var connectionTask = TaskCache.CompletedTask;
- var userTask = TaskCache.CompletedTask;
+ var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet());
+ var connectionTask = Task.CompletedTask;
+ var userTask = Task.CompletedTask;
_connections.Add(connection);
var connectionChannel = typeof(THub).FullName + "." + connection.ConnectionId;
redisSubscriptions.Add(connectionChannel);
- var previousConnectionTask = TaskCache.CompletedTask;
+ var previousConnectionTask = Task.CompletedTask;
+ _logger.LogInformation("Subscribing to connection channel: {channel}", connectionChannel);
connectionTask = _bus.SubscribeAsync(connectionChannel, async (c, data) =>
{
await previousConnectionTask;
- previousConnectionTask = WriteAsync(connection, data);
+ var message = DeserializeMessage(data);
+
+ previousConnectionTask = WriteAsync(connection, message);
});
@@ -142,14 +157,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var userChannel = typeof(THub).FullName + ".user." + connection.User.Identity.Name;
redisSubscriptions.Add(userChannel);
- var previousUserTask = TaskCache.CompletedTask;
+ var previousUserTask = Task.CompletedTask;
// TODO: Look at optimizing (looping over connections checking for Name)
userTask = _bus.SubscribeAsync(userChannel, async (c, data) =>
{
await previousUserTask;
- previousUserTask = WriteAsync(connection, data);
+ var message = DeserializeMessage(data);
+
+ previousUserTask = WriteAsync(connection, message);
});
}
@@ -162,16 +179,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var tasks = new List();
- var redisSubscriptions = connection.Metadata.Get>("redis_subscriptions");
+ var redisSubscriptions = connection.Metadata.Get>(RedisSubscriptionsMetadataName);
if (redisSubscriptions != null)
{
foreach (var subscription in redisSubscriptions)
{
+ _logger.LogInformation("Unsubscribing from channel: {channel}", subscription);
tasks.Add(_bus.UnsubscribeAsync(subscription));
}
}
- var groupNames = connection.Metadata.Get>("group");
+ var groupNames = connection.Metadata.Get>(HubConnectionMetadataNames.Groups);
if (groupNames != null)
{
@@ -190,7 +208,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
var groupChannel = typeof(THub).FullName + ".group." + groupName;
- var groupNames = connection.Metadata.GetOrAdd("group", _ => new HashSet());
+ var groupNames = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet());
lock (groupNames)
{
@@ -210,8 +228,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return;
}
- var previousTask = TaskCache.CompletedTask;
+ var previousTask = Task.CompletedTask;
+ _logger.LogInformation("Subscribing to group channel: {channel}", groupChannel);
await _bus.SubscribeAsync(groupChannel, async (c, data) =>
{
// Since this callback is async, we await the previous task then
@@ -219,10 +238,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis
// want to do concurrent writes to the outgoing connections
await previousTask;
+ var message = DeserializeMessage(data);
+
var tasks = new List(group.Connections.Count);
foreach (var groupConnection in group.Connections)
{
- tasks.Add(WriteAsync(groupConnection, data));
+ tasks.Add(WriteAsync(groupConnection, message));
}
previousTask = Task.WhenAll(tasks);
@@ -244,7 +265,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return;
}
- var groupNames = connection.Metadata.Get>("group");
+ var groupNames = connection.Metadata.Get>(HubConnectionMetadataNames.Groups);
if (groupNames != null)
{
lock (groupNames)
@@ -260,6 +281,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
if (group.Connections.Count == 0)
{
+ _logger.LogInformation("Unsubscribing from group channel: {channel}", groupChannel);
await _bus.UnsubscribeAsync(groupChannel);
}
}
@@ -275,9 +297,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
_redisServerConnection.Dispose();
}
- private async Task WriteAsync(Connection connection, byte[] data)
+ private async Task WriteAsync(Connection connection, HubMessage hubMessage)
{
- var message = new Message(data, MessageType.Text, endOfMessage: true);
+ var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol);
+ var data = await protocol.WriteToArrayAsync(hubMessage);
+ var message = new Message(data, protocol.MessageType, endOfMessage: true);
while (await connection.Transport.Output.WaitToWriteAsync())
{
@@ -288,6 +312,23 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
+ private string GetInvocationId()
+ {
+ var invocationId = Interlocked.Increment(ref _nextInvocationId);
+ return invocationId.ToString();
+ }
+
+ private HubMessage DeserializeMessage(RedisValue data)
+ {
+ HubMessage message;
+ using (var reader = new JsonTextReader(new StreamReader(new MemoryStream((byte[])data))))
+ {
+ message = (HubMessage)_serializer.Deserialize(reader);
+ }
+
+ return message;
+ }
+
private class LoggerTextWriter : TextWriter
{
private readonly ILogger _logger;
diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs
index a2842441f0..c8b83fe4a2 100644
--- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs
+++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs
@@ -3,43 +3,37 @@
using System;
using System.Collections.Generic;
-using System.IO;
-using System.IO.Pipelines;
+using System.Threading;
using System.Threading.Tasks;
+using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
-using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.SignalR
{
public class DefaultHubLifetimeManager : HubLifetimeManager
{
+ private long _nextInvocationId = 0;
private readonly ConnectionList _connections = new ConnectionList();
- private readonly InvocationAdapterRegistry _registry;
-
- public DefaultHubLifetimeManager(InvocationAdapterRegistry registry)
- {
- _registry = registry;
- }
public override Task AddGroupAsync(Connection connection, string groupName)
{
- var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet());
+ var groups = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet());
lock (groups)
{
groups.Add(groupName);
}
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
}
public override Task RemoveGroupAsync(Connection connection, string groupName)
{
- var groups = connection.Metadata.Get>("groups");
+ var groups = connection.Metadata.Get>(HubConnectionMetadataNames.Groups);
if (groups == null)
{
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
}
lock (groups)
@@ -47,7 +41,7 @@ namespace Microsoft.AspNetCore.SignalR
groups.Remove(groupName);
}
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
}
public override Task InvokeAllAsync(string methodName, object[] args)
@@ -58,11 +52,7 @@ namespace Microsoft.AspNetCore.SignalR
private Task InvokeAllWhere(string methodName, object[] args, Func include)
{
var tasks = new List(_connections.Count);
- var message = new InvocationDescriptor
- {
- Method = methodName,
- Arguments = args
- };
+ var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
// TODO: serialize once per format by providing a different stream?
foreach (var connection in _connections)
@@ -72,9 +62,7 @@ namespace Microsoft.AspNetCore.SignalR
continue;
}
- var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get("formatType"));
-
- tasks.Add(WriteAsync(connection, invocationAdapter, message));
+ tasks.Add(WriteAsync(connection, message));
}
return Task.WhenAll(tasks);
@@ -84,22 +72,16 @@ namespace Microsoft.AspNetCore.SignalR
{
var connection = _connections[connectionId];
- var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get("formatType"));
+ var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
- var message = new InvocationDescriptor
- {
- Method = methodName,
- Arguments = args
- };
-
- return WriteAsync(connection, invocationAdapter, message);
+ return WriteAsync(connection, message);
}
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
{
return InvokeAllWhere(methodName, args, connection =>
{
- var groups = connection.Metadata.Get>("groups");
+ var groups = connection.Metadata.Get>(HubConnectionMetadataNames.Groups);
return groups?.Contains(groupName) == true;
});
}
@@ -115,21 +97,20 @@ namespace Microsoft.AspNetCore.SignalR
public override Task OnConnectedAsync(Connection connection)
{
_connections.Add(connection);
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
}
public override Task OnDisconnectedAsync(Connection connection)
{
_connections.Remove(connection);
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
}
- private static async Task WriteAsync(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor invocation)
+ private async Task WriteAsync(Connection connection, HubMessage hubMessage)
{
- var stream = new MemoryStream();
- await invocationAdapter.WriteMessageAsync(invocation, stream);
-
- var message = new Message(stream.ToArray(), MessageType.Text, endOfMessage: true);
+ var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol);
+ var payload = await protocol.WriteToArrayAsync(hubMessage);
+ var message = new Message(payload, protocol.MessageType, endOfMessage: true);
while (await connection.Transport.Output.WaitToWriteAsync())
{
@@ -139,5 +120,11 @@ namespace Microsoft.AspNetCore.SignalR
}
}
}
+
+ private string GetInvocationId()
+ {
+ var invocationId = Interlocked.Increment(ref _nextInvocationId);
+ return invocationId.ToString();
+ }
}
}
diff --git a/src/Microsoft.AspNetCore.SignalR/Hub.cs b/src/Microsoft.AspNetCore.SignalR/Hub.cs
index 5ff206d408..3ecc69bf51 100644
--- a/src/Microsoft.AspNetCore.SignalR/Hub.cs
+++ b/src/Microsoft.AspNetCore.SignalR/Hub.cs
@@ -3,7 +3,6 @@
using System;
using System.Threading.Tasks;
-using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.SignalR
{
@@ -62,12 +61,12 @@ namespace Microsoft.AspNetCore.SignalR
public virtual Task OnConnectedAsync()
{
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
}
public virtual Task OnDisconnectedAsync(Exception exception)
{
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
}
protected virtual void Dispose(bool disposing)
diff --git a/src/Microsoft.AspNetCore.SignalR.Common/InvocationMessage.cs b/src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs
similarity index 54%
rename from src/Microsoft.AspNetCore.SignalR.Common/InvocationMessage.cs
rename to src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs
index fc0c1bb5ca..702a762c01 100644
--- a/src/Microsoft.AspNetCore.SignalR.Common/InvocationMessage.cs
+++ b/src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs
@@ -1,15 +1,11 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Threading.Tasks;
-
namespace Microsoft.AspNetCore.SignalR
{
- public abstract class InvocationMessage
+ public static class HubConnectionMetadataNames
{
- public string Id { get; set; }
+ public static readonly string HubProtocol = nameof(HubProtocol);
+ public static readonly string Groups = nameof(Groups);
}
}
diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs
index bbe5f6717c..82af05346f 100644
--- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs
+++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs
@@ -1,13 +1,14 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
-using System.IO;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.AspNetCore.SignalR.Internal;
+using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;
@@ -19,12 +20,12 @@ namespace Microsoft.AspNetCore.SignalR
public class HubEndPoint : HubEndPoint where THub : Hub
{
public HubEndPoint(HubLifetimeManager lifetimeManager,
+ IHubProtocolResolver protocolResolver,
IHubContext hubContext,
- InvocationAdapterRegistry registry,
IOptions>> endPointOptions,
ILogger> logger,
IServiceScopeFactory serviceScopeFactory)
- : base(lifetimeManager, hubContext, registry, endPointOptions, logger, serviceScopeFactory)
+ : base(lifetimeManager, protocolResolver, hubContext, endPointOptions, logger, serviceScopeFactory)
{
}
}
@@ -36,19 +37,19 @@ namespace Microsoft.AspNetCore.SignalR
private readonly HubLifetimeManager _lifetimeManager;
private readonly IHubContext _hubContext;
private readonly ILogger> _logger;
- private readonly InvocationAdapterRegistry _registry;
private readonly IServiceScopeFactory _serviceScopeFactory;
+ private readonly IHubProtocolResolver _protocolResolver;
public HubEndPoint(HubLifetimeManager lifetimeManager,
+ IHubProtocolResolver protocolResolver,
IHubContext hubContext,
- InvocationAdapterRegistry registry,
IOptions>> endPointOptions,
ILogger> logger,
IServiceScopeFactory serviceScopeFactory)
{
+ _protocolResolver = protocolResolver;
_lifetimeManager = lifetimeManager;
_hubContext = hubContext;
- _registry = registry;
_logger = logger;
_serviceScopeFactory = serviceScopeFactory;
@@ -59,6 +60,11 @@ namespace Microsoft.AspNetCore.SignalR
{
try
{
+ // Resolve the Hub Protocol for the connection and store it in metadata
+ // Other components, outside the Hub, may need to know what protocol is in use
+ // for a particular connection, so we store it here.
+ connection.Metadata[HubConnectionMetadataNames.HubProtocol] = _protocolResolver.GetProtocol(connection);
+
await _lifetimeManager.OnConnectedAsync(connection);
await RunHubAsync(connection);
}
@@ -140,14 +146,13 @@ namespace Microsoft.AspNetCore.SignalR
private async Task DispatchMessagesAsync(Connection connection)
{
- var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get("formatType"));
-
// We use these for error handling. Since we dispatch multiple hub invocations
// in parallel, we need a way to communicate failure back to the main processing loop. The
// cancellation token is used to stop reading from the channel, the tcs
// is used to get the exception so we can bubble it up the stack
var cts = new CancellationTokenSource();
- var tcs = new TaskCompletionSource();
+ var completion = new TaskCompletionSource();
+ var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol);
try
{
@@ -155,100 +160,94 @@ namespace Microsoft.AspNetCore.SignalR
{
while (connection.Transport.Input.TryRead(out var incomingMessage))
{
- InvocationDescriptor invocationDescriptor;
- var inputStream = new MemoryStream(incomingMessage.Payload);
+ var hubMessage = protocol.ParseMessage(incomingMessage.Payload, this);
- // TODO: Handle receiving InvocationResultDescriptor
- invocationDescriptor = await invocationAdapter.ReadMessageAsync(inputStream, this) as InvocationDescriptor;
-
- // Is there a better way of detecting that a connection was closed?
- if (invocationDescriptor == null)
+ switch (hubMessage)
{
- break;
- }
+ case InvocationMessage invocationMessage:
+ if (_logger.IsEnabled(LogLevel.Debug))
+ {
+ _logger.LogDebug("Received hub invocation: {invocation}", invocationMessage);
+ }
- if (_logger.IsEnabled(LogLevel.Debug))
- {
- _logger.LogDebug("Received hub invocation: {invocation}", invocationDescriptor);
- }
+ // Don't wait on the result of execution, continue processing other
+ // incoming messages on this connection.
+ var ignore = ProcessInvocation(connection, protocol, invocationMessage, cts, completion);
+ break;
- // Don't wait on the result of execution, continue processing other
- // incoming messages on this connection.
- var ignore = ProcessInvocation(connection, invocationAdapter, invocationDescriptor, cts, tcs);
+ // Other kind of message we weren't expecting
+ default:
+ _logger.LogError("Received unsupported message of type '{messageType}'", hubMessage.GetType().FullName);
+ throw new NotSupportedException($"Received unsupported message: {hubMessage}");
+ }
}
}
}
catch (OperationCanceledException)
{
// Await the task so the exception bubbles up to the caller
- await tcs.Task;
+ await completion.Task;
}
}
private async Task ProcessInvocation(Connection connection,
- IInvocationAdapter invocationAdapter,
- InvocationDescriptor invocationDescriptor,
- CancellationTokenSource cts,
- TaskCompletionSource tcs)
+ IHubProtocol protocol,
+ InvocationMessage invocationMessage,
+ CancellationTokenSource dispatcherCancellation,
+ TaskCompletionSource dispatcherCompletion)
{
try
{
// If an unexpected exception occurs then we want to kill the entire connection
// by ending the processing loop
- await Execute(connection, invocationAdapter, invocationDescriptor);
+ await Execute(connection, protocol, invocationMessage);
}
catch (Exception ex)
{
// Set the exception on the task completion source
- tcs.TrySetException(ex);
+ dispatcherCompletion.TrySetException(ex);
// Cancel reading operation
- cts.Cancel();
+ dispatcherCancellation.Cancel();
}
}
- private async Task Execute(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor invocationDescriptor)
+ private async Task Execute(Connection connection, IHubProtocol protocol, InvocationMessage invocationMessage)
{
- InvocationResultDescriptor result;
HubMethodDescriptor descriptor;
- if (_methods.TryGetValue(invocationDescriptor.Method, out descriptor))
+ if (!_methods.TryGetValue(invocationMessage.Target, out descriptor))
{
- result = await Invoke(descriptor, connection, invocationDescriptor);
+ // Send an error to the client. Then let the normal completion process occur
+ _logger.LogError("Unknown hub method '{method}'", invocationMessage.Target);
+ await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, $"Unknown hub method '{invocationMessage.Target}'"));
}
else
{
- // If there's no method then return a failed response for this request
- result = new InvocationResultDescriptor
- {
- Id = invocationDescriptor.Id,
- Error = $"Unknown hub method '{invocationDescriptor.Method}'"
- };
-
- _logger.LogError("Unknown hub method '{method}'", invocationDescriptor.Method);
- }
-
- // TODO: Pool memory
- var outStream = new MemoryStream();
- await invocationAdapter.WriteMessageAsync(result, outStream);
-
- var outMessage = new Message(outStream.ToArray(), MessageType.Text, endOfMessage: true);
-
- while (await connection.Transport.Output.WaitToWriteAsync())
- {
- if (connection.Transport.Output.TryWrite(outMessage))
- {
- break;
- }
+ var result = await Invoke(descriptor, connection, invocationMessage);
+ await SendMessageAsync(connection, protocol, result);
}
}
- private async Task Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationDescriptor invocationDescriptor)
+ private async Task SendMessageAsync(Connection connection, IHubProtocol protocol, HubMessage hubMessage)
{
- var invocationResult = new InvocationResultDescriptor
- {
- Id = invocationDescriptor.Id
- };
+ var payload = await protocol.WriteToArrayAsync(hubMessage);
+ var message = new Message(payload, protocol.MessageType, endOfMessage: true);
+ while (await connection.Transport.Output.WaitToWriteAsync())
+ {
+ if (connection.Transport.Output.TryWrite(message))
+ {
+ return;
+ }
+ }
+
+ // Output is closed. Cancel this invocation completely
+ _logger.LogWarning("Outbound channel was closed while trying to write hub message");
+ throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
+ }
+
+ private async Task Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationMessage invocationMessage)
+ {
var methodExecutor = descriptor.MethodExecutor;
using (var scope = _serviceScopeFactory.CreateScope())
@@ -265,37 +264,35 @@ namespace Microsoft.AspNetCore.SignalR
{
if (methodExecutor.MethodReturnType == typeof(Task))
{
- await (Task)methodExecutor.Execute(hub, invocationDescriptor.Arguments);
+ await (Task)methodExecutor.Execute(hub, invocationMessage.Arguments);
}
else
{
- result = await methodExecutor.ExecuteAsync(hub, invocationDescriptor.Arguments);
+ result = await methodExecutor.ExecuteAsync(hub, invocationMessage.Arguments);
}
}
else
{
- result = methodExecutor.Execute(hub, invocationDescriptor.Arguments);
+ result = methodExecutor.Execute(hub, invocationMessage.Arguments);
}
- invocationResult.Result = result;
+ return CompletionMessage.WithResult(invocationMessage.InvocationId, result);
}
catch (TargetInvocationException ex)
{
_logger.LogError(0, ex, "Failed to invoke hub method");
- invocationResult.Error = ex.InnerException.Message;
+ return CompletionMessage.WithError(invocationMessage.InvocationId, ex.InnerException.Message);
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Failed to invoke hub method");
- invocationResult.Error = ex.Message;
+ return CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message);
}
finally
{
hubActivator.Release(hub);
}
}
-
- return invocationResult;
}
private void InitializeHub(THub hub, Connection connection)
diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs
new file mode 100644
index 0000000000..2d06b8bcee
--- /dev/null
+++ b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs
@@ -0,0 +1,18 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using Microsoft.AspNetCore.SignalR.Internal.Protocol;
+using Microsoft.AspNetCore.Sockets;
+using Newtonsoft.Json;
+
+namespace Microsoft.AspNetCore.SignalR.Internal
+{
+ public class DefaultHubProtocolResolver : IHubProtocolResolver
+ {
+ public IHubProtocol GetProtocol(Connection connection)
+ {
+ // TODO: Allow customization of this serializer!
+ return new JsonHubProtocol(new JsonSerializer());
+ }
+ }
+}
diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs
new file mode 100644
index 0000000000..c9627d0d59
--- /dev/null
+++ b/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs
@@ -0,0 +1,13 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using Microsoft.AspNetCore.SignalR.Internal.Protocol;
+using Microsoft.AspNetCore.Sockets;
+
+namespace Microsoft.AspNetCore.SignalR.Internal
+{
+ public interface IHubProtocolResolver
+ {
+ IHubProtocol GetProtocol(Connection connection);
+ }
+}
diff --git a/src/Microsoft.AspNetCore.SignalR/InvocationAdapterRegistry.cs b/src/Microsoft.AspNetCore.SignalR/InvocationAdapterRegistry.cs
deleted file mode 100644
index 112b8d9345..0000000000
--- a/src/Microsoft.AspNetCore.SignalR/InvocationAdapterRegistry.cs
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright (c) .NET Foundation. All rights reserved.
-// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-
-using System;
-using Microsoft.Extensions.DependencyInjection;
-using Microsoft.Extensions.Options;
-
-namespace Microsoft.AspNetCore.SignalR
-{
- public class InvocationAdapterRegistry
- {
- private readonly IServiceProvider _serviceProvider;
- private readonly SignalROptions _options;
-
- public InvocationAdapterRegistry(IOptions options, IServiceProvider serviceProvider)
- {
- _options = options.Value;
- _serviceProvider = serviceProvider;
- }
-
- public IInvocationAdapter GetInvocationAdapter(string format)
- {
- Type type;
- if (_options._invocationMappings.TryGetValue(format, out type))
- {
- return _serviceProvider.GetRequiredService(type) as IInvocationAdapter;
- }
-
- return null;
- }
- }
-}
\ No newline at end of file
diff --git a/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj b/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj
index 83080985e8..aedd17a019 100644
--- a/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj
+++ b/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj
@@ -14,7 +14,6 @@
-
diff --git a/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs
index bceb79c04f..71a5f022ef 100644
--- a/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs
+++ b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs
@@ -1,9 +1,8 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-using System;
using Microsoft.AspNetCore.SignalR;
-using Microsoft.Extensions.Options;
+using Microsoft.AspNetCore.SignalR.Internal;
namespace Microsoft.Extensions.DependencyInjection
{
@@ -13,26 +12,13 @@ namespace Microsoft.Extensions.DependencyInjection
{
services.AddSockets();
services.AddSingleton(typeof(HubLifetimeManager<>), typeof(DefaultHubLifetimeManager<>));
+ services.AddSingleton(typeof(IHubProtocolResolver), typeof(DefaultHubProtocolResolver));
services.AddSingleton(typeof(IHubContext<>), typeof(HubContext<>));
services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>));
- services.AddSingleton, SignalROptionsSetup>();
- services.AddSingleton();
- services.AddSingleton();
services.AddScoped(typeof(IHubActivator<,>), typeof(DefaultHubActivator<,>));
services.AddRouting();
return new SignalRBuilder(services);
}
-
- public static ISignalRBuilder AddSignalR(this IServiceCollection services, Action setupAction)
- {
- return services.AddSignalR().AddSignalROptions(setupAction);
- }
-
- public static ISignalRBuilder AddSignalROptions(this ISignalRBuilder builder, Action setupAction)
- {
- builder.Services.Configure(setupAction);
- return builder;
- }
}
}
diff --git a/src/Microsoft.AspNetCore.SignalR/SignalROptions.cs b/src/Microsoft.AspNetCore.SignalR/SignalROptions.cs
deleted file mode 100644
index 65a00b85ca..0000000000
--- a/src/Microsoft.AspNetCore.SignalR/SignalROptions.cs
+++ /dev/null
@@ -1,20 +0,0 @@
-// Copyright (c) .NET Foundation. All rights reserved.
-// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Threading.Tasks;
-
-namespace Microsoft.AspNetCore.SignalR
-{
- public class SignalROptions
- {
- internal readonly Dictionary _invocationMappings = new Dictionary();
-
- public void RegisterInvocationAdapter(string format) where TInvocationAdapter : IInvocationAdapter
- {
- _invocationMappings[format] = typeof(TInvocationAdapter);
- }
- }
-}
diff --git a/src/Microsoft.AspNetCore.SignalR/SignalROptionsSetup.cs b/src/Microsoft.AspNetCore.SignalR/SignalROptionsSetup.cs
deleted file mode 100644
index 546d621437..0000000000
--- a/src/Microsoft.AspNetCore.SignalR/SignalROptionsSetup.cs
+++ /dev/null
@@ -1,17 +0,0 @@
-// Copyright (c) .NET Foundation. All rights reserved.
-// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-
-using System;
-using Microsoft.AspNetCore.SignalR;
-using Microsoft.Extensions.Options;
-
-namespace Microsoft.Extensions.DependencyInjection
-{
- public class SignalROptionsSetup : IConfigureOptions
- {
- public void Configure(SignalROptions options)
- {
- options.RegisterInvocationAdapter("json");
- }
- }
-}
\ No newline at end of file
diff --git a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs
index 6d69d094ed..f3c3576fec 100644
--- a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs
+++ b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs
@@ -10,7 +10,6 @@ using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
-using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
@@ -58,7 +57,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
return t;
}).Unwrap();
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
}
public async Task StopAsync()
diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj b/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj
index 66031fa62c..3e236ac65f 100644
--- a/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj
+++ b/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj
@@ -17,7 +17,6 @@
-
diff --git a/src/Microsoft.AspNetCore.Sockets.Client/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/ServerSentEventsTransport.cs
index 981247100b..14a0d9792c 100644
--- a/src/Microsoft.AspNetCore.Sockets.Client/ServerSentEventsTransport.cs
+++ b/src/Microsoft.AspNetCore.Sockets.Client/ServerSentEventsTransport.cs
@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
@@ -8,7 +8,6 @@ using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
-using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
@@ -61,7 +60,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
return t;
}).Unwrap();
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
}
private async Task OpenConnection(IChannelConnection application, Uri url, CancellationToken cancellationToken)
diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj b/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj
index 0e12b5c4cc..ebb3ced79e 100644
--- a/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj
+++ b/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj
@@ -12,6 +12,10 @@
false
+
+
+
+
diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionMetadataNames.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionMetadataNames.cs
new file mode 100644
index 0000000000..646673782a
--- /dev/null
+++ b/src/Microsoft.AspNetCore.Sockets/ConnectionMetadataNames.cs
@@ -0,0 +1,12 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+namespace Microsoft.AspNetCore.Sockets
+{
+ public static class ConnectionMetadataNames
+ {
+ public static readonly string Format = nameof(Format);
+ public static readonly string Transport = nameof(Transport);
+ public static readonly string HttpContext = nameof(HttpContext);
+ }
+}
diff --git a/src/Microsoft.AspNetCore.Sockets/EndpointDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs
similarity index 93%
rename from src/Microsoft.AspNetCore.Sockets/EndpointDependencyInjectionExtensions.cs
rename to src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs
index a83814c326..c09c6118fd 100644
--- a/src/Microsoft.AspNetCore.Sockets/EndpointDependencyInjectionExtensions.cs
+++ b/src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs
@@ -6,7 +6,7 @@ using Microsoft.AspNetCore.Sockets;
namespace Microsoft.Extensions.DependencyInjection
{
- public static class EndpointDependencyInjectionExtensions
+ public static class EndPointDependencyInjectionExtensions
{
public static IServiceCollection AddEndPoint(this IServiceCollection services) where TEndPoint : EndPoint
{
diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs
index 0be2cf9ad5..1edd8b71be 100644
--- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs
+++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs
@@ -167,7 +167,7 @@ namespace Microsoft.AspNetCore.Sockets
{
_logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId);
- state.Connection.Metadata["transport"] = TransportType.LongPolling;
+ state.Connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling;
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
}
@@ -232,13 +232,10 @@ namespace Microsoft.AspNetCore.Sockets
private ConnectionState CreateConnection(HttpContext context)
{
var state = _manager.CreateConnection();
+ var format = (string)context.Request.Query[ConnectionMetadataNames.Format];
state.Connection.User = context.User;
-
- // TODO: this is wrong. + how does the user add their own metadata based on HttpContext
- var formatType = (string)context.Request.Query["formatType"];
- state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType;
- state.Connection.Metadata[typeof(HttpContext)] = context;
-
+ state.Connection.Metadata[ConnectionMetadataNames.HttpContext] = context;
+ state.Connection.Metadata[ConnectionMetadataNames.Format] = string.IsNullOrEmpty(format) ? "json" : format;
return state;
}
@@ -355,7 +352,7 @@ namespace Microsoft.AspNetCore.Sockets
var messages = ParseSendBatch(ref reader, messageFormat);
// REVIEW: Do we want to return a specific status code here if the connection has ended?
- _logger.LogDebug("Received batch of {0} message(s) in '/send'", messages.Count);
+ _logger.LogDebug("Received batch of {count} message(s) in '/send'", messages.Count);
foreach (var message in messages)
{
while (!state.Application.Output.TryWrite(message))
@@ -379,11 +376,11 @@ namespace Microsoft.AspNetCore.Sockets
connectionState.Connection.User = context.User;
- var transport = connectionState.Connection.Metadata.Get("transport");
+ var transport = connectionState.Connection.Metadata.Get(ConnectionMetadataNames.Transport);
if (transport == null)
{
- connectionState.Connection.Metadata["transport"] = transportType;
+ connectionState.Connection.Metadata[ConnectionMetadataNames.Transport] = transportType;
}
else if (transport != transportType)
{
diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs
index 6e3bd605bb..c73423b87f 100644
--- a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs
+++ b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs
@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
@@ -38,7 +38,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
public async Task DisposeAsync()
{
- Task disposeTask = TaskCache.CompletedTask;
+ Task disposeTask = Task.CompletedTask;
try
{
@@ -69,8 +69,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal
Connection.Dispose();
Application.Dispose();
- var applicationTask = ApplicationTask ?? TaskCache.CompletedTask;
- var transportTask = TransportTask ?? TaskCache.CompletedTask;
+ var applicationTask = ApplicationTask ?? Task.CompletedTask;
+ var transportTask = TransportTask ?? Task.CompletedTask;
disposeTask = WaitOnTasks(applicationTask, transportTask);
}
diff --git a/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj b/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj
index 0247ea2f83..fc9ba09d5f 100644
--- a/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj
+++ b/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj
@@ -18,7 +18,6 @@
-
diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs
index 1a3fe2519b..bd2c8347bf 100644
--- a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs
+++ b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs
@@ -134,7 +134,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
// Is this a frame we care about?
if (!frame.Opcode.IsMessage())
{
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
}
LogFrame("Receiving", frame);
diff --git a/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs b/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs
index 88a88c07b7..001437eb3c 100644
--- a/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs
+++ b/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs
@@ -4,7 +4,6 @@
using System;
using System.Threading;
using System.Threading.Tasks;
-using Microsoft.Extensions.Internal;
namespace Microsoft.Extensions.WebSockets.Internal
{
@@ -135,7 +134,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
connection.ExecuteAsync((frame, _) =>
{
messageHandler(frame);
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
}, null);
///
@@ -149,7 +148,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
connection.ExecuteAsync((frame, s) =>
{
messageHandler(frame, s);
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
}, state);
///
diff --git a/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj b/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj
index 6d15f40632..d886a6890c 100644
--- a/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj
+++ b/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj
@@ -17,7 +17,6 @@
-
diff --git a/test/Common/TaskExtensions.cs b/test/Common/TaskExtensions.cs
index 51c7b549e4..10b31649bc 100644
--- a/test/Common/TaskExtensions.cs
+++ b/test/Common/TaskExtensions.cs
@@ -1,7 +1,8 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
+using System.Runtime.CompilerServices;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Tests.Common
@@ -10,36 +11,48 @@ namespace Microsoft.AspNetCore.SignalR.Tests.Common
{
private const int DefaultTimeout = 5000;
- public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout)
+ public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
{
- return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds));
+ return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber);
}
- public static async Task OrTimeout(this Task task, TimeSpan timeout)
+ public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
{
var completed = await Task.WhenAny(task, Task.Delay(timeout));
if (completed != task)
{
- throw new TimeoutException();
+ throw new TimeoutException(GetMessage(memberName, filePath, lineNumber));
}
await task;
}
- public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout)
+ public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
{
- return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds));
+ return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber);
}
- public static async Task OrTimeout(this Task task, TimeSpan timeout)
+ public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
{
var completed = await Task.WhenAny(task, Task.Delay(timeout));
if (completed != task)
{
- throw new TimeoutException();
+ throw new TimeoutException(GetMessage(memberName, filePath, lineNumber));
}
return await task;
}
+
+ private static string GetMessage(string memberName, string filePath, int? lineNumber)
+ {
+ if (!string.IsNullOrEmpty(memberName))
+ {
+ return $"Operation in {memberName} timed out at {filePath}:{lineNumber}";
+ }
+ else
+ {
+ return "Operation timed out";
+ }
+ }
}
}
diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs
index 51cf710546..24e5a3425f 100644
--- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs
+++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs
@@ -98,7 +98,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
using (var httpClient = _testServer.CreateClient())
{
- var connection = new HubConnection(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), loggerFactory);
+ var connection = new HubConnection(new Uri("http://test/hubs"), loggerFactory);
try
{
await connection.StartAsync(TransportType.LongPolling, httpClient);
@@ -133,13 +133,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
tcs.TrySetResult((string)a[0]);
});
- await connection.Invoke("CallEcho", originalMessage);
+ await connection.Invoke("CallEcho", originalMessage).OrTimeout();
Assert.Equal(originalMessage, await tcs.Task.OrTimeout());
}
finally
{
- await connection.DisposeAsync();
+ await connection.DisposeAsync().OrTimeout();
}
}
}
diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs
new file mode 100644
index 0000000000..de53017a80
--- /dev/null
+++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs
@@ -0,0 +1,156 @@
+using System;
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.SignalR.Internal.Protocol;
+using Microsoft.AspNetCore.SignalR.Tests.Common;
+using Microsoft.Extensions.Logging;
+using Newtonsoft.Json;
+using Xunit;
+
+namespace Microsoft.AspNetCore.SignalR.Client.Tests
+{
+ // This includes tests that verify HubConnection conforms to the Hub Protocol, without setting up a full server (even TestServer).
+ // We can also have more control over the messages we send to HubConnection in order to ensure that protocol errors and other quirks
+ // don't cause problems.
+ public class HubConnectionProtocolTests
+ {
+ [Fact]
+ public async Task InvokeSendsAnInvocationMessage()
+ {
+ var connection = new TestConnection();
+ var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
+ try
+ {
+ await hubConnection.StartAsync();
+
+ var invokeTask = hubConnection.Invoke("Foo", typeof(void));
+
+ var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout();
+
+ Assert.Equal("{\"invocationId\":\"1\",\"type\":1,\"target\":\"Foo\",\"arguments\":[]}", invokeMessage);
+ }
+ finally
+ {
+ await hubConnection.DisposeAsync().OrTimeout();
+ await connection.DisposeAsync().OrTimeout();
+ }
+ }
+
+ [Fact]
+ public async Task InvokeCompletedWhenCompletionMessageReceived()
+ {
+ var connection = new TestConnection();
+ var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
+ try
+ {
+ await hubConnection.StartAsync();
+
+ var invokeTask = hubConnection.Invoke("Foo", typeof(void));
+
+ await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout();
+
+ await invokeTask.OrTimeout();
+ }
+ finally
+ {
+ await hubConnection.DisposeAsync().OrTimeout();
+ await connection.DisposeAsync().OrTimeout();
+ }
+ }
+
+ [Fact]
+ public async Task InvokeYieldsResultWhenCompletionMessageReceived()
+ {
+ var connection = new TestConnection();
+ var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
+ try
+ {
+ await hubConnection.StartAsync();
+
+ var invokeTask = hubConnection.Invoke("Foo");
+
+ await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = 42 }).OrTimeout();
+
+ Assert.Equal(42, await invokeTask.OrTimeout());
+ }
+ finally
+ {
+ await hubConnection.DisposeAsync().OrTimeout();
+ await connection.DisposeAsync().OrTimeout();
+ }
+ }
+
+ [Fact]
+ public async Task InvokeFailsWithExceptionWhenCompletionWithErrorReceived()
+ {
+ var connection = new TestConnection();
+ var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
+ try
+ {
+ await hubConnection.StartAsync();
+
+ var invokeTask = hubConnection.Invoke("Foo");
+
+ await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout();
+
+ var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout();
+ Assert.Equal("An error occurred", ex.Message);
+ }
+ finally
+ {
+ await hubConnection.DisposeAsync().OrTimeout();
+ await connection.DisposeAsync().OrTimeout();
+ }
+ }
+
+ [Fact]
+ // This will fail (intentionally) when we support streaming!
+ public async Task InvokeFailsWithErrorWhenStreamingItemReceived()
+ {
+ var connection = new TestConnection();
+ var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
+ try
+ {
+ await hubConnection.StartAsync();
+
+ var invokeTask = hubConnection.Invoke("Foo");
+
+ await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, result = 42 }).OrTimeout();
+
+ var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout();
+ Assert.Equal("Streaming method results are not supported", ex.Message);
+ }
+ finally
+ {
+ await hubConnection.DisposeAsync().OrTimeout();
+ await connection.DisposeAsync().OrTimeout();
+ }
+ }
+
+ [Fact]
+ public async Task HandlerRegisteredWithOnIsFiredWhenInvocationReceived()
+ {
+ var connection = new TestConnection();
+ var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory());
+ var handlerCalled = new TaskCompletionSource();
+ try
+ {
+ await hubConnection.StartAsync();
+
+ hubConnection.On("Foo", new[] { typeof(int), typeof(string), typeof(float) }, (a) => handlerCalled.TrySetResult(a));
+
+ await connection.ReceiveJsonMessage(new { invocationId = "1", type = 1, target = "Foo", arguments = new object[] { 1, "Foo", 2.0f } }).OrTimeout();
+
+ var results = await handlerCalled.Task.OrTimeout();
+ Assert.Equal(3, results.Length);
+ Assert.Equal(1, results[0]);
+ Assert.Equal("Foo", results[1]);
+ Assert.Equal(2.0f, results[2]);
+ }
+ finally
+ {
+ await hubConnection.DisposeAsync().OrTimeout();
+ await connection.DisposeAsync().OrTimeout();
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs
index 3298a26e03..0755c598e4 100644
--- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs
+++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs
@@ -2,12 +2,13 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
-using System.Collections.Generic;
-using System.IO;
+using System.Buffers;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Client.Tests;
+using Microsoft.AspNetCore.SignalR.Internal;
+using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
@@ -25,14 +26,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
public void CannotCreateHubConnectionWithNullUrl()
{
var exception = Assert.Throws(
- () => new HubConnection((Uri)null, Mock.Of(), Mock.Of()));
+ () => new HubConnection((Uri)null, Mock.Of()));
Assert.Equal("url", exception.ParamName);
}
[Fact]
public async Task CanDisposeNotStartedHubConnection()
{
- await new HubConnection(new Uri("http://fakeuri.org"), Mock.Of(), new LoggerFactory())
+ await new HubConnection(new Uri("http://fakeuri.org"), new LoggerFactory())
.DisposeAsync();
}
@@ -50,7 +51,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
- var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), Mock.Of(), new LoggerFactory());
+ var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), new LoggerFactory());
try
{
@@ -81,7 +82,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
- var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), Mock.Of(), new LoggerFactory());
+ var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), new LoggerFactory());
await hubConnection.StartAsync(TransportType.LongPolling, httpClient);
await hubConnection.DisposeAsync();
@@ -110,7 +111,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var exception =
await Assert.ThrowsAsync(async () => await hubConnection.Invoke("test"));
- Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message);
+ Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message);
}
[Fact]
@@ -119,12 +120,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var mockConnection = new Mock();
var exception = new InvalidOperationException();
- var mockInvocationAdapter = new Mock();
- mockInvocationAdapter
- .Setup(a => a.WriteMessageAsync(It.IsAny(), It.IsAny(), It.IsAny()))
- .Returns(Task.FromException(exception));
-
- var hubConnection = new HubConnection(mockConnection.Object, mockInvocationAdapter.Object, null);
+ var mockProtocol = MockHubProtocol.Throw(exception);
+ var hubConnection = new HubConnection(mockConnection.Object, mockProtocol, null);
await hubConnection.StartAsync(TransportType.All, httpClient: null);
var actualException =
@@ -146,7 +143,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
- var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), Mock.Of(), new LoggerFactory());
+ var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), new LoggerFactory());
try
{
var connectedEventRaisedTcs = new TaskCompletionSource();
@@ -177,7 +174,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
- var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), Mock.Of(), new LoggerFactory());
+ var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), new LoggerFactory());
var closedEventTcs = new TaskCompletionSource();
hubConnection.Closed += e => closedEventTcs.SetResult(e);
@@ -197,7 +194,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
.Callback(() => mockConnection.Raise(c => c.Closed += null, (Exception)null))
.Returns(Task.FromResult(null));
- var hubConnection = new HubConnection(mockConnection.Object, Mock.Of(), new LoggerFactory());
+ var hubConnection = new HubConnection(mockConnection.Object, new LoggerFactory());
await hubConnection.StartAsync(new TestTransportFactory(Mock.Of()), httpClient: null);
await hubConnection.DisposeAsync();
@@ -216,7 +213,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
.Callback(() => mockConnection.Raise(c => c.Closed += null, (Exception)null))
.Returns(Task.FromResult(null));
- var hubConnection = new HubConnection(mockConnection.Object, Mock.Of(), new LoggerFactory());
+ var hubConnection = new HubConnection(mockConnection.Object, new LoggerFactory());
await hubConnection.StartAsync(new TestTransportFactory(Mock.Of()), httpClient: null);
var invokeTask = hubConnection.Invoke("testMethod", typeof(int));
@@ -235,7 +232,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
.Callback(() => mockConnection.Raise(c => c.Closed += null, exception))
.Returns(Task.FromResult(null));
- var hubConnection = new HubConnection(mockConnection.Object, Mock.Of(), new LoggerFactory());
+ var hubConnection = new HubConnection(mockConnection.Object, new LoggerFactory());
await hubConnection.StartAsync(new TestTransportFactory(Mock.Of()), httpClient: null);
var invokeTask = hubConnection.Invoke("testMethod", typeof(int));
@@ -250,23 +247,69 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var mockConnection = new Mock();
- var invocationDescriptor = new InvocationDescriptor
- {
- Method = "NonExistingMethod123",
- Arguments = new object[] { true, "arg2", 123 },
- Id = Guid.NewGuid().ToString()
- };
+ var invocation = new InvocationMessage(Guid.NewGuid().ToString(), nonBlocking: true, target: "NonExistingMethod123", arguments: new object[] { true, "arg2", 123 });
- var mockInvocationAdapter = new Mock();
- mockInvocationAdapter
- .Setup(a => a.ReadMessageAsync(It.IsAny(), It.IsAny(), It.IsAny()))
- .Returns(Task.FromResult((InvocationMessage)invocationDescriptor));
+ var mockProtocol = MockHubProtocol.ReturnOnParse(invocation);
- var hubConnection = new HubConnection(mockConnection.Object, mockInvocationAdapter.Object, null);
+ var hubConnection = new HubConnection(mockConnection.Object, mockProtocol, null);
await hubConnection.StartAsync(new TestTransportFactory(Mock.Of()), httpClient: null);
mockConnection.Raise(c => c.Received += null, new object[] { new byte[] { }, MessageType.Text });
- mockInvocationAdapter.Verify(a => a.ReadMessageAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Once());
+ Assert.Equal(1, mockProtocol.ParseCalls);
+ }
+
+ // Moq really doesn't handle out parameters well, so to make these tests work I added a manual mock -anurse
+ private class MockHubProtocol : IHubProtocol
+ {
+ private HubMessage _parsed;
+ private Exception _error;
+
+ public int ParseCalls { get; private set; } = 0;
+ public int WriteCalls { get; private set; } = 0;
+
+ public MessageType MessageType => MessageType.Text;
+
+ public static MockHubProtocol ReturnOnParse(HubMessage parsed)
+ {
+ return new MockHubProtocol
+ {
+ _parsed = parsed
+ };
+ }
+
+ public static MockHubProtocol Throw(Exception error)
+ {
+ return new MockHubProtocol
+ {
+ _error = error
+ };
+ }
+
+ public HubMessage ParseMessage(ReadOnlySpan input, IInvocationBinder binder)
+ {
+ ParseCalls += 1;
+ if (_error != null)
+ {
+ throw _error;
+ }
+ if (_parsed != null)
+ {
+ return _parsed;
+ }
+
+ throw new InvalidOperationException("No Parsed Message provided");
+ }
+
+ public bool TryWriteMessage(HubMessage message, IOutput output)
+ {
+ WriteCalls += 1;
+
+ if (_error != null)
+ {
+ throw _error;
+ }
+ return true;
+ }
}
}
}
diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs
new file mode 100644
index 0000000000..8a5eb49675
--- /dev/null
+++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs
@@ -0,0 +1,116 @@
+using System;
+using System.Net.Http;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using System.Threading.Tasks.Channels;
+using Microsoft.AspNetCore.Sockets;
+using Microsoft.AspNetCore.Sockets.Client;
+using Newtonsoft.Json;
+
+namespace Microsoft.AspNetCore.SignalR.Client.Tests
+{
+ internal class TestConnection : IConnection
+ {
+ private TaskCompletionSource _started = new TaskCompletionSource();
+ private TaskCompletionSource _disposed = new TaskCompletionSource();
+
+ private Channel _sentMessages = Channel.CreateUnbounded();
+ private Channel _receivedMessages = Channel.CreateUnbounded();
+
+ private CancellationTokenSource _receiveShutdownToken = new CancellationTokenSource();
+ private Task _receiveLoop;
+
+ public event Action Connected;
+ public event Action Received;
+ public event Action Closed;
+
+ public Task Started => _started.Task;
+ public Task Disposed => _disposed.Task;
+ public ReadableChannel SentMessages => _sentMessages.In;
+ public WritableChannel ReceivedMessages => _receivedMessages.Out;
+
+ public TestConnection()
+ {
+ _receiveLoop = ReceiveLoopAsync(_receiveShutdownToken.Token);
+ }
+
+ public Task DisposeAsync()
+ {
+ _disposed.TrySetResult(null);
+ _receiveShutdownToken.Cancel();
+ return _receiveLoop;
+ }
+
+ public async Task SendAsync(byte[] data, MessageType type, CancellationToken cancellationToken)
+ {
+ if(!_started.Task.IsCompleted)
+ {
+ throw new InvalidOperationException("Connection must be started before SendAsync can be called");
+ }
+
+ var message = new Message(data, type, endOfMessage: true);
+ while (await _sentMessages.Out.WaitToWriteAsync(cancellationToken))
+ {
+ if (_sentMessages.Out.TryWrite(message))
+ {
+ return;
+ }
+ }
+ throw new ObjectDisposedException("Unable to send message, underlying channel was closed");
+ }
+
+ public Task StartAsync(ITransportFactory transportFactory, HttpClient httpClient)
+ {
+ _started.TrySetResult(null);
+ Connected?.Invoke();
+ return Task.CompletedTask;
+ }
+
+ public async Task ReadSentTextMessageAsync()
+ {
+ var message = await SentMessages.ReadAsync();
+ if (message.Type != MessageType.Text)
+ {
+ throw new InvalidOperationException($"Unexpected message of type: {message.Type}");
+ }
+ return Encoding.UTF8.GetString(message.Payload);
+ }
+
+ public Task ReceiveJsonMessage(object jsonObject)
+ {
+ var json = JsonConvert.SerializeObject(jsonObject, Formatting.None);
+ var bytes = Encoding.UTF8.GetBytes(json);
+ var message = new Message(bytes, MessageType.Text);
+
+ return _receivedMessages.Out.WriteAsync(message);
+ }
+
+ private async Task ReceiveLoopAsync(CancellationToken token)
+ {
+ try
+ {
+ while (!token.IsCancellationRequested)
+ {
+ while (await _receivedMessages.In.WaitToReadAsync(token))
+ {
+ while (_receivedMessages.In.TryRead(out var message))
+ {
+ Received?.Invoke(message.Payload, message.Type);
+ }
+ }
+ }
+ Closed?.Invoke(null);
+ }
+ catch (OperationCanceledException)
+ {
+ // Do nothing, we were just asked to shut down.
+ Closed?.Invoke(null);
+ }
+ catch (Exception ex)
+ {
+ Closed?.Invoke(ex);
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs
new file mode 100644
index 0000000000..6dc6fb4a0d
--- /dev/null
+++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs
@@ -0,0 +1,257 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.SignalR.Internal;
+using Microsoft.AspNetCore.SignalR.Internal.Protocol;
+using Newtonsoft.Json;
+using Newtonsoft.Json.Serialization;
+using Xunit;
+
+namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
+{
+ public class JsonHubProtocolTests
+ {
+ public static IEnumerable ProtocolTestData => new[]
+ {
+ new object[] { new InvocationMessage("123", true, "Target", 1, "Foo", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"nonBlocking\":true,\"arguments\":[1,\"Foo\",2.0]}" },
+ new object[] { new InvocationMessage("123", false, "Target", 1, "Foo", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[1,\"Foo\",2.0]}" },
+ new object[] { new InvocationMessage("123", false, "Target", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[true]}" },
+ new object[] { new InvocationMessage("123", false, "Target", new object[] { null }), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[null]}" },
+ new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}]}" },
+ new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}]}" },
+ new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}]}" },
+ new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}]}" },
+
+ new object[] { new StreamItemMessage("123", 1), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":1}" },
+ new object[] { new StreamItemMessage("123", "Foo"), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":\"Foo\"}" },
+ new object[] { new StreamItemMessage("123", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":2.0}" },
+ new object[] { new StreamItemMessage("123", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":true}" },
+ new object[] { new StreamItemMessage("123", null), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":null}" },
+ new object[] { new StreamItemMessage("123", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}}" },
+ new object[] { new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}}" },
+ new object[] { new StreamItemMessage("123", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}}" },
+ new object[] { new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}}" },
+
+ new object[] { CompletionMessage.WithResult("123", 1), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":1}" },
+ new object[] { CompletionMessage.WithResult("123", "Foo"), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":\"Foo\"}" },
+ new object[] { CompletionMessage.WithResult("123", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":2.0}" },
+ new object[] { CompletionMessage.WithResult("123", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":true}" },
+ new object[] { CompletionMessage.WithResult("123", null), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":null}" },
+ new object[] { CompletionMessage.WithError("123", "Whoops!"), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"error\":\"Whoops!\"}" },
+ new object[] { CompletionMessage.WithResult("123", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}}" },
+ new object[] { CompletionMessage.WithResult("123", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}}" },
+ new object[] { CompletionMessage.WithResult("123", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}}" },
+ new object[] { CompletionMessage.WithResult("123", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}}" },
+ };
+
+ [Theory]
+ [MemberData(nameof(ProtocolTestData))]
+ public async Task WriteMessage(HubMessage message, bool camelCase, NullValueHandling nullValueHandling, string expectedOutput)
+ {
+ var jsonSerializer = new JsonSerializer
+ {
+ NullValueHandling = nullValueHandling,
+ ContractResolver = camelCase ? new CamelCasePropertyNamesContractResolver() : new DefaultContractResolver()
+ };
+
+ var protocol = new JsonHubProtocol(jsonSerializer);
+ var encoded = await protocol.WriteToArrayAsync(message);
+ var json = Encoding.UTF8.GetString(encoded);
+
+ Assert.Equal(expectedOutput, json);
+ }
+
+ [Theory]
+ [MemberData(nameof(ProtocolTestData))]
+ public void ParseMessage(HubMessage expectedMessage, bool camelCase, NullValueHandling nullValueHandling, string input)
+ {
+ var jsonSerializer = new JsonSerializer
+ {
+ NullValueHandling = nullValueHandling,
+ ContractResolver = camelCase ? new CamelCasePropertyNamesContractResolver() : new DefaultContractResolver()
+ };
+
+ var binder = new TestBinder(expectedMessage);
+ var protocol = new JsonHubProtocol(jsonSerializer);
+ var message = protocol.ParseMessage(Encoding.UTF8.GetBytes(input), binder);
+
+ Assert.Equal(expectedMessage, message, TestEqualityComparer.Instance);
+ }
+
+ [Theory]
+ [InlineData("", "Error reading JSON.")]
+ [InlineData("null", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")]
+ [InlineData("42", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")]
+ [InlineData("'foo'", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")]
+ [InlineData("[42]", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")]
+ [InlineData("{}", "Missing required property 'type'.")]
+
+ [InlineData("{'type':1}", "Missing required property 'invocationId'.")]
+ [InlineData("{'type':1,'invocationId':42}", "Expected 'invocationId' to be of type String.")]
+ [InlineData("{'type':1,'invocationId':'42','target':42}", "Expected 'target' to be of type String.")]
+ [InlineData("{'type':1,'invocationId':'42','target':'foo'}", "Missing required property 'arguments'.")]
+ [InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':{}}", "Expected 'arguments' to be of type Array.")]
+
+ [InlineData("{'type':2}", "Missing required property 'invocationId'.")]
+ [InlineData("{'type':2,'invocationId':42}", "Expected 'invocationId' to be of type String.")]
+ [InlineData("{'type':2,'invocationId':'42'}", "Missing required property 'result'.")]
+
+ [InlineData("{'type':3}", "Missing required property 'invocationId'.")]
+ [InlineData("{'type':3,'invocationId':42}", "Expected 'invocationId' to be of type String.")]
+ [InlineData("{'type':3,'invocationId':'42','error':[]}", "Expected 'error' to be of type String.")]
+
+ [InlineData("{'type':4}", "Unknown message type: 4")]
+ [InlineData("{'type':'foo'}", "Expected 'type' to be of type Integer.")]
+ public void InvalidMessages(string input, string expectedMessage)
+ {
+ var binder = new TestBinder();
+ var protocol = new JsonHubProtocol(new JsonSerializer());
+ var ex = Assert.Throws(() => protocol.ParseMessage(Encoding.UTF8.GetBytes(input), binder));
+ Assert.Equal(expectedMessage, ex.Message);
+ }
+
+ [Theory]
+ [InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':[]}", "Invocation provides 0 argument(s) but target expects 2.")]
+ [InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':[42, 'foo'],'nonBlocking':42}", "Expected 'nonBlocking' to be of type Boolean.")]
+ [InlineData("{'type':3,'invocationId':'42','error':'foo','result':true}", "The 'error' and 'result' properties are mutually exclusive.")]
+ public void InvalidMessagesWithBinder(string input, string expectedMessage)
+ {
+ var binder = new TestBinder(paramTypes: new[] { typeof(int), typeof(string) }, returnType: typeof(bool));
+ var protocol = new JsonHubProtocol(new JsonSerializer());
+ var ex = Assert.Throws(() => protocol.ParseMessage(Encoding.UTF8.GetBytes(input), binder));
+ Assert.Equal(expectedMessage, ex.Message);
+ }
+
+ private class CustomObject : IEquatable
+ {
+ // Not intended to be a full set of things, just a smattering of sample serializations
+ public string StringProp => "SignalR!";
+
+ public double DoubleProp => 6.2831853071;
+
+ public int IntProp => 42;
+
+ public DateTime DateTimeProp => new DateTime(2017, 4, 11);
+
+ public object NullProp => null;
+
+ public override bool Equals(object obj)
+ {
+ return obj is CustomObject o && Equals(o);
+ }
+
+ public override int GetHashCode()
+ {
+ // This is never used in a hash table
+ return 0;
+ }
+
+ public bool Equals(CustomObject right)
+ {
+ // This allows the comparer below to properly compare the object in the test.
+ return string.Equals(StringProp, right.StringProp, StringComparison.Ordinal) &&
+ DoubleProp == right.DoubleProp &&
+ IntProp == right.IntProp &&
+ DateTime.Equals(DateTimeProp, right.DateTimeProp) &&
+ NullProp == right.NullProp;
+ }
+ }
+
+ // Binder that works based on the expected message argument/result types :)
+ private class TestBinder : IInvocationBinder
+ {
+ private readonly Type[] _paramTypes;
+ private readonly Type _returnType;
+
+ public TestBinder(HubMessage expectedMessage)
+ {
+ switch(expectedMessage)
+ {
+ case InvocationMessage i:
+ _paramTypes = i.Arguments?.Select(a => a?.GetType() ?? typeof(object))?.ToArray();
+ break;
+ case StreamItemMessage s:
+ _returnType = s.Item?.GetType() ?? typeof(object);
+ break;
+ case CompletionMessage c:
+ _returnType = c.Result?.GetType() ?? typeof(object);
+ break;
+ }
+ }
+
+ public TestBinder() : this(null, null) { }
+ public TestBinder(Type[] paramTypes) : this(paramTypes, null) { }
+ public TestBinder(Type returnType) : this(null, returnType) {}
+ public TestBinder(Type[] paramTypes, Type returnType)
+ {
+ _paramTypes = paramTypes;
+ _returnType = returnType;
+ }
+
+ public Type[] GetParameterTypes(string methodName)
+ {
+ if (_paramTypes != null)
+ {
+ return _paramTypes;
+ }
+ throw new InvalidOperationException("Unexpected binder call");
+ }
+
+ public Type GetReturnType(string invocationId)
+ {
+ if (_returnType != null)
+ {
+ return _returnType;
+ }
+ throw new InvalidOperationException("Unexpected binder call");
+ }
+ }
+
+ private class TestEqualityComparer : IEqualityComparer
+ {
+ public static readonly TestEqualityComparer Instance = new TestEqualityComparer();
+
+ private TestEqualityComparer() { }
+
+ public bool Equals(HubMessage x, HubMessage y)
+ {
+ if (!string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal))
+ {
+ return false;
+ }
+
+ return InvocationMessagesEqual(x, y) || StreamItemMessagesEqual(x, y) || CompletionMessagesEqual(x, y);
+ }
+
+ private bool CompletionMessagesEqual(HubMessage x, HubMessage y)
+ {
+ return x is CompletionMessage left && y is CompletionMessage right &&
+ string.Equals(left.Error, right.Error, StringComparison.Ordinal) &&
+ Equals(left.Result, right.Result) &&
+ left.HasResult == right.HasResult;
+ }
+
+ private bool StreamItemMessagesEqual(HubMessage x, HubMessage y)
+ {
+ return x is StreamItemMessage left && y is StreamItemMessage right &&
+ Equals(left.Item, right.Item);
+ }
+
+ private bool InvocationMessagesEqual(HubMessage x, HubMessage y)
+ {
+ return x is InvocationMessage left && y is InvocationMessage right &&
+ string.Equals(left.Target, right.Target, StringComparison.Ordinal) &&
+ Enumerable.SequenceEqual(left.Arguments, right.Arguments) &&
+ left.NonBlocking == right.NonBlocking;
+ }
+
+ public int GetHashCode(HubMessage obj)
+ {
+ // We never use these in a hash-table
+ return 0;
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj
new file mode 100644
index 0000000000..d06c98408d
--- /dev/null
+++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj
@@ -0,0 +1,31 @@
+
+
+
+
+
+ netcoreapp2.0;net46
+ netcoreapp2.0
+
+ true
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs
index 84af03faf6..91b51ba380 100644
--- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs
+++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs
@@ -1,14 +1,14 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
+using Microsoft.AspNetCore.SignalR.Internal.Protocol;
+using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
-using Microsoft.Extensions.Internal;
using Moq;
using Xunit;
-using Microsoft.AspNetCore.SignalR.Tests.Common;
namespace Microsoft.AspNetCore.SignalR.Tests
{
@@ -19,7 +19,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var trackDispose = new TrackDispose();
var serviceProvider = CreateServiceProvider(s => s.AddSingleton(trackDispose));
- var endPoint = serviceProvider.GetService>();
+ var endPoint = serviceProvider.GetService>();
using (var client = new TestClient(serviceProvider))
{
@@ -127,10 +127,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
- var result = await client.Invoke(nameof(MethodHub.TaskValueMethod)).OrTimeout();
+ var result = (await client.InvokeAsync(nameof(MethodHub.TaskValueMethod)).OrTimeout()).Result;
// json serializer makes this a long
- Assert.Equal(42L, result.Result);
+ Assert.Equal(42L, result);
// kill the connection
client.Dispose();
@@ -150,10 +150,33 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
- var result = await client.Invoke("echo", "hello").OrTimeout();
+ var result = (await client.InvokeAsync("echo", "hello").OrTimeout()).Result;
- Assert.Null(result.Error);
- Assert.Equal("hello", result.Result);
+ Assert.Equal("hello", result);
+
+ // kill the connection
+ client.Dispose();
+
+ await endPointTask.OrTimeout();
+ }
+ }
+
+ [Theory]
+ [InlineData(nameof(MethodHub.MethodThatThrows))]
+ [InlineData(nameof(MethodHub.MethodThatYieldsFailedTask))]
+ public async Task HubMethodCanThrowOrYieldFailedTask(string methodName)
+ {
+ var serviceProvider = CreateServiceProvider();
+
+ var endPoint = serviceProvider.GetService>();
+
+ using (var client = new TestClient(serviceProvider))
+ {
+ var endPointTask = endPoint.OnConnectedAsync(client.Connection);
+
+ var result = (await client.InvokeAsync(methodName).OrTimeout());
+
+ Assert.Equal("BOOM!", result.Error);
// kill the connection
client.Dispose();
@@ -173,10 +196,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
- var result = await client.Invoke(nameof(MethodHub.ValueMethod)).OrTimeout();
+ var result = (await client.InvokeAsync(nameof(MethodHub.ValueMethod)).OrTimeout()).Result;
// json serializer makes this a long
- Assert.Equal(43L, result.Result);
+ Assert.Equal(43L, result);
// kill the connection
client.Dispose();
@@ -196,9 +219,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
- var result = await client.Invoke(nameof(MethodHub.VoidMethod)).OrTimeout();
+ var result = (await client.InvokeAsync(nameof(MethodHub.VoidMethod)).OrTimeout()).Result;
- Assert.Null(result.Result);
+ Assert.Null(result);
// kill the connection
client.Dispose();
@@ -218,9 +241,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
- var result = await client.Invoke(nameof(MethodHub.ConcatString), (byte)32, 42, 'm', "string").OrTimeout();
+ var result = (await client.InvokeAsync(nameof(MethodHub.ConcatString), (byte)32, 42, 'm', "string").OrTimeout()).Result;
- Assert.Equal("32, 42, m, string", result.Result);
+ Assert.Equal("32, 42, m, string", result);
// kill the connection
client.Dispose();
@@ -240,9 +263,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
- var result = await client.Invoke(nameof(InheritedHub.BaseMethod), "string").OrTimeout();
+ var result = (await client.InvokeAsync(nameof(InheritedHub.BaseMethod), "string").OrTimeout()).Result;
- Assert.Equal("string", result.Result);
+ Assert.Equal("string", result);
// kill the connection
client.Dispose();
@@ -262,9 +285,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
- var result = await client.Invoke(nameof(InheritedHub.VirtualMethod), 10).OrTimeout();
+ var result = (await client.InvokeAsync(nameof(InheritedHub.VirtualMethod), 10).OrTimeout()).Result;
- Assert.Equal(0L, result.Result);
+ Assert.Equal(0L, result);
// kill the connection
client.Dispose();
@@ -284,7 +307,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
- var result = await client.Invoke(nameof(MethodHub.OnDisconnectedAsync)).OrTimeout();
+ var result = await client.InvokeAsync(nameof(MethodHub.OnDisconnectedAsync)).OrTimeout();
Assert.Equal("Unknown hub method 'OnDisconnectedAsync'", result.Error);
@@ -326,15 +349,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
- await firstClient.Invoke(nameof(MethodHub.BroadcastMethod), "test").OrTimeout();
+ await firstClient.SendInvocationAsync(nameof(MethodHub.BroadcastMethod), "test").OrTimeout();
foreach (var result in await Task.WhenAll(
- firstClient.Read(),
- secondClient.Read()).OrTimeout())
+ firstClient.Read(),
+ secondClient.Read()).OrTimeout())
{
- Assert.Equal("Broadcast", result.Method);
- Assert.Equal(1, result.Arguments.Length);
- Assert.Equal("test", result.Arguments[0]);
+ var invocation = Assert.IsType(result);
+ Assert.Equal("Broadcast", invocation.Target);
+ Assert.Equal(1, invocation.Arguments.Length);
+ Assert.Equal("test", invocation.Arguments[0]);
}
// kill the connections
@@ -360,23 +384,24 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
- var result = await firstClient.Invoke(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout();
+ var result = (await firstClient.InvokeAsync(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout()).Result;
+
// check that 'firstConnection' hasn't received the group send
- Assert.Null(result.Id);
+ Assert.Null(firstClient.TryRead());
// check that 'secondConnection' hasn't received the group send
- Assert.Null(await secondClient.TryRead().OrTimeout());
+ Assert.Null(secondClient.TryRead());
- result = await secondClient.Invoke(nameof(MethodHub.GroupAddMethod), "testGroup").OrTimeout();
- Assert.Null(result.Id);
+ result = (await secondClient.InvokeAsync(nameof(MethodHub.GroupAddMethod), "testGroup").OrTimeout()).Result;
- await firstClient.Invoke(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout();
+ await firstClient.SendInvocationAsync(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout();
// check that 'secondConnection' has received the group send
- var descriptor = await secondClient.Read().OrTimeout();
- Assert.Equal("Send", descriptor.Method);
- Assert.Equal(1, descriptor.Arguments.Length);
- Assert.Equal("test", descriptor.Arguments[0]);
+ var hubMessage = await secondClient.Read().OrTimeout();
+ var invocation = Assert.IsType(hubMessage);
+ Assert.Equal("Send", invocation.Target);
+ Assert.Equal(1, invocation.Arguments.Length);
+ Assert.Equal("test", invocation.Arguments[0]);
// kill the connections
firstClient.Dispose();
@@ -397,7 +422,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
- await client.Invoke(nameof(MethodHub.GroupRemoveMethod), "testGroup").OrTimeout();
+ await client.SendInvocationAsync(nameof(MethodHub.GroupRemoveMethod), "testGroup").OrTimeout();
// kill the connection
client.Dispose();
@@ -421,13 +446,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
- await firstClient.Invoke(nameof(MethodHub.ClientSendMethod), secondClient.Connection.User.Identity.Name, "test").OrTimeout();
+ await firstClient.SendInvocationAsync(nameof(MethodHub.ClientSendMethod), secondClient.Connection.User.Identity.Name, "test").OrTimeout();
// check that 'secondConnection' has received the group send
- var result = await secondClient.Read().OrTimeout();
- Assert.Equal("Send", result.Method);
- Assert.Equal(1, result.Arguments.Length);
- Assert.Equal("test", result.Arguments[0]);
+ var hubMessage = await secondClient.Read().OrTimeout();
+ var invocation = Assert.IsType(hubMessage);
+ Assert.Equal("Send", invocation.Target);
+ Assert.Equal(1, invocation.Arguments.Length);
+ Assert.Equal("test", invocation.Arguments[0]);
// kill the connections
firstClient.Dispose();
@@ -452,13 +478,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
- await firstClient.Invoke(nameof(MethodHub.ConnectionSendMethod), secondClient.Connection.ConnectionId, "test").OrTimeout();
+ await firstClient.SendInvocationAsync(nameof(MethodHub.ConnectionSendMethod), secondClient.Connection.ConnectionId, "test").OrTimeout();
// check that 'secondConnection' has received the group send
- var result = await secondClient.Read().OrTimeout();
- Assert.Equal("Send", result.Method);
- Assert.Equal(1, result.Arguments.Length);
- Assert.Equal("test", result.Arguments[0]);
+ var hubMessage = await secondClient.Read().OrTimeout();
+ var invocation = Assert.IsType(hubMessage);
+ Assert.Equal("Send", invocation.Target);
+ Assert.Equal(1, invocation.Arguments.Length);
+ Assert.Equal("test", invocation.Arguments[0]);
// kill the connections
firstClient.Dispose();
@@ -575,7 +602,17 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public override Task OnDisconnectedAsync(Exception e)
{
- return TaskCache.CompletedTask;
+ return Task.CompletedTask;
+ }
+
+ public void MethodThatThrows()
+ {
+ throw new InvalidOperationException("BOOM!");
+ }
+
+ public Task MethodThatYieldsFailedTask()
+ {
+ return Task.FromException(new InvalidOperationException("BOOM!"));
}
}
@@ -611,11 +648,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
- private class TestHub : Hub
+ private class DisposeTrackingHub : Hub
{
private TrackDispose _trackDispose;
- public TestHub(TrackDispose trackDispose)
+ public DisposeTrackingHub(TrackDispose trackDispose)
{
_trackDispose = trackDispose;
}
diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs
index 85b9b3ec36..1139ff5295 100644
--- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs
+++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs
@@ -2,29 +2,30 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
-using System.IO;
+using System.Collections.Generic;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
+using Microsoft.AspNetCore.SignalR.Internal;
+using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Internal;
-using Microsoft.Extensions.DependencyInjection;
+using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Tests
{
- public class TestClient : IDisposable
+ public class TestClient : IDisposable, IInvocationBinder
{
private static int _id;
- private IInvocationAdapter _adapter;
+ private IHubProtocol _protocol;
private CancellationTokenSource _cts;
- private TestBinder _binder;
public Connection Connection;
public IChannelConnection Application { get; }
public Task Connected => Connection.Metadata.Get>("ConnectedTask").Task;
- public TestClient(IServiceProvider serviceProvider, string format = "json")
+ public TestClient(IServiceProvider serviceProvider)
{
var transportToApplication = Channel.CreateUnbounded();
var applicationToTransport = Channel.CreateUnbounded();
@@ -33,62 +34,80 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var transport = ChannelConnection.Create(input: transportToApplication, output: applicationToTransport);
Connection = new Connection(Guid.NewGuid().ToString(), transport);
- Connection.Metadata["formatType"] = format;
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));
Connection.Metadata["ConnectedTask"] = new TaskCompletionSource();
- var invocationAdapter = serviceProvider.GetService();
- _adapter = invocationAdapter.GetInvocationAdapter(format);
-
- _binder = new TestBinder();
+ _protocol = new JsonHubProtocol(new JsonSerializer());
_cts = new CancellationTokenSource();
}
- public async Task Invoke(string methodName, params object[] args) where T : InvocationMessage
+ public async Task InvokeAsync(string methodName, params object[] args)
{
- await Invoke(methodName, args);
+ var invocationId = await SendInvocationAsync(methodName, args);
- return await Read();
- }
-
- public async Task Invoke(string methodName, params object[] args)
- {
- var stream = new MemoryStream();
- await _adapter.WriteMessageAsync(new InvocationDescriptor
+ while (true)
{
- Arguments = args,
- Method = methodName
- },
- stream);
+ var message = await Read();
- await Application.Output.WriteAsync(new Message(stream.ToArray(), MessageType.Binary, endOfMessage: true));
- }
-
- public async Task Read() where T : InvocationMessage
- {
- while (await Application.Input.WaitToReadAsync(_cts.Token))
- {
- var value = await TryRead();
-
- if (value != null)
+ if (!string.Equals(message.InvocationId, invocationId))
{
- return value;
+ throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
+ }
+
+ if (message == null)
+ {
+ throw new InvalidOperationException("Connection aborted!");
+ }
+
+ switch (message)
+ {
+ case StreamItemMessage result:
+ throw new NotSupportedException("TestClient does not support streaming!");
+ case CompletionMessage completion:
+ return completion;
+ default:
+ throw new NotSupportedException("TestClient does not support receiving invocations!");
}
}
-
- return null;
}
- public async Task TryRead() where T : InvocationMessage
+ public async Task SendInvocationAsync(string methodName, params object[] args)
{
- Message message;
- if (Application.Input.TryRead(out message))
- {
- var value = await _adapter.ReadMessageAsync(new MemoryStream(message.Payload), _binder);
- return value as T;
- }
+ var invocationId = GetInvocationId();
+ var payload = await _protocol.WriteToArrayAsync(new InvocationMessage(invocationId, nonBlocking: false, target: methodName, arguments: args));
+ await Application.Output.WriteAsync(new Message(payload, _protocol.MessageType, endOfMessage: true));
+
+ return invocationId;
+ }
+
+ public async Task Read()
+ {
+ while (true)
+ {
+ var message = TryRead();
+
+ if (message == null)
+ {
+ if (!await Application.Input.WaitToReadAsync())
+ {
+ return null;
+ }
+ }
+ else
+ {
+ return message;
+ }
+ }
+ }
+
+ public HubMessage TryRead()
+ {
+ if (Application.Input.TryRead(out var message))
+ {
+ return _protocol.ParseMessage(message.Payload, this);
+ }
return null;
}
@@ -98,18 +117,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Connection.Dispose();
}
- private class TestBinder : IInvocationBinder
+ private static string GetInvocationId()
{
- public Type[] GetParameterTypes(string methodName)
- {
- // TODO: Possibly support actual client methods
- return new[] { typeof(object) };
- }
+ return Guid.NewGuid().ToString("N");
+ }
- public Type GetReturnType(string invocationId)
- {
- return typeof(object);
- }
+ Type[] IInvocationBinder.GetParameterTypes(string methodName)
+ {
+ // TODO: Possibly support actual client methods
+ return new[] { typeof(object) };
+ }
+
+ Type IInvocationBinder.GetReturnType(string invocationId)
+ {
+ return typeof(object);
}
}
}