From 991c1d851777ab92fe81bbae05c052a38897350b Mon Sep 17 00:00:00 2001 From: Andrew Stanton-Nurse Date: Tue, 9 May 2017 12:24:58 -0700 Subject: [PATCH] Implement new Hub Protocol (Part Deux) (#390) * convert to new protocol * removed InvocationDescriptorRegistry because we're not yet sure about custom protocols * update SocialWeather sample * Moving ts client to using new protocol * make the functional tests a little easier to run on ctrl-f5 --- SignalR.sln | 9 +- build/dependencies.props | 3 +- .../HubConnection.spec.ts | 117 ++- .../HubConnection.ts | 116 ++- .../Startup.cs | 4 +- .../wwwroot/default.html | 13 + samples/ClientSample/HubSample.cs | 3 +- .../PersistentConnectionLifeTimeManager.cs | 7 +- .../EndPoints/MessagesEndPoint.cs | 4 +- .../SocketsSample/LineInvocationAdapter.cs | 91 -- .../Protobuf/ProtobufInvocationAdapter.cs | 168 ---- .../SocketsSample/Protobuf/RpcInvocation.cs | 848 ------------------ .../Protobuf/RpcInvocation.proto | 30 - samples/SocketsSample/ProtobufSerializer.cs | 36 - samples/SocketsSample/Startup.cs | 13 +- .../IOutputExtensions.cs | 0 .../HubConnection.cs | 275 ++++-- .../HubException.cs | 23 + .../IInvocationAdapter.cs | 16 - .../{ => Internal}/IInvocationBinder.cs | 2 +- .../Internal/Protocol/CompletionMessage.cs | 38 + .../Internal/Protocol/HubMessage.cs | 17 + .../HubProtocolWriteMessageExtensions.cs | 37 + .../Internal/Protocol/IHubProtocol.cs | 18 + .../Internal/Protocol/InvocationMessage.cs | 44 + .../Internal/Protocol/JsonHubProtocol.cs | 281 ++++++ .../Internal/Protocol/StreamItemMessage.cs | 20 + .../InvocationAdapterExtensions.cs | 16 - .../InvocationDescriptor.cs | 19 - .../InvocationResultDescriptor.cs | 17 - .../JsonNetInvocationAdapter.cs | 89 -- ...Microsoft.AspNetCore.SignalR.Common.csproj | 11 +- .../Microsoft.AspNetCore.SignalR.Redis.csproj | 1 - .../RedisHubLifetimeManager.cs | 153 ++-- .../DefaultHubLifetimeManager.cs | 63 +- src/Microsoft.AspNetCore.SignalR/Hub.cs | 5 +- .../HubConnectionMetadataNames.cs} | 10 +- .../HubEndPoint.cs | 143 ++- .../Internal/DefaultHubProtocolResolver.cs | 18 + .../Internal/IHubProtocolResolver.cs | 13 + .../InvocationAdapterRegistry.cs | 32 - .../Microsoft.AspNetCore.SignalR.csproj | 1 - .../SignalRDependencyInjectionExtensions.cs | 20 +- .../SignalROptions.cs | 20 - .../SignalROptionsSetup.cs | 17 - .../LongPollingTransport.cs | 3 +- ...Microsoft.AspNetCore.Sockets.Client.csproj | 1 - .../ServerSentEventsTransport.cs | 5 +- ...Microsoft.AspNetCore.Sockets.Common.csproj | 4 + .../ConnectionMetadataNames.cs | 12 + ... EndPointDependencyInjectionExtensions.cs} | 2 +- .../HttpConnectionDispatcher.cs | 17 +- .../Internal/ConnectionState.cs | 8 +- .../Microsoft.AspNetCore.Sockets.csproj | 1 - .../Transports/WebSocketsTransport.cs | 2 +- .../IWebSocketConnection.cs | 5 +- ...soft.Extensions.WebSockets.Internal.csproj | 1 - test/Common/TaskExtensions.cs | 31 +- .../HubConnectionTests.cs | 6 +- .../HubConnectionProtocolTests.cs | 156 ++++ .../HubConnectionTests.cs | 103 ++- .../TestConnection.cs | 116 +++ .../Internal/Protocol/JsonHubProtocolTests.cs | 257 ++++++ ...oft.AspNetCore.SignalR.Common.Tests.csproj | 31 + .../HubEndpointTests.cs | 137 +-- .../TestClient.cs | 129 +-- 66 files changed, 1957 insertions(+), 1951 deletions(-) create mode 100644 client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/default.html delete mode 100644 samples/SocketsSample/LineInvocationAdapter.cs delete mode 100644 samples/SocketsSample/Protobuf/ProtobufInvocationAdapter.cs delete mode 100644 samples/SocketsSample/Protobuf/RpcInvocation.cs delete mode 100644 samples/SocketsSample/Protobuf/RpcInvocation.proto delete mode 100644 samples/SocketsSample/ProtobufSerializer.cs rename src/{Microsoft.AspNetCore.Sockets.Common/Internal/Formatters => Common}/IOutputExtensions.cs (100%) create mode 100644 src/Microsoft.AspNetCore.SignalR.Client/HubException.cs delete mode 100644 src/Microsoft.AspNetCore.SignalR.Common/IInvocationAdapter.cs rename src/Microsoft.AspNetCore.SignalR.Common/{ => Internal}/IInvocationBinder.cs (87%) create mode 100644 src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CompletionMessage.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolWriteMessageExtensions.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/InvocationMessage.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/StreamItemMessage.cs delete mode 100644 src/Microsoft.AspNetCore.SignalR.Common/InvocationAdapterExtensions.cs delete mode 100644 src/Microsoft.AspNetCore.SignalR.Common/InvocationDescriptor.cs delete mode 100644 src/Microsoft.AspNetCore.SignalR.Common/InvocationResultDescriptor.cs delete mode 100644 src/Microsoft.AspNetCore.SignalR.Common/JsonNetInvocationAdapter.cs rename src/{Microsoft.AspNetCore.SignalR.Common/InvocationMessage.cs => Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs} (54%) create mode 100644 src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs create mode 100644 src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs delete mode 100644 src/Microsoft.AspNetCore.SignalR/InvocationAdapterRegistry.cs delete mode 100644 src/Microsoft.AspNetCore.SignalR/SignalROptions.cs delete mode 100644 src/Microsoft.AspNetCore.SignalR/SignalROptionsSetup.cs create mode 100644 src/Microsoft.AspNetCore.Sockets/ConnectionMetadataNames.cs rename src/Microsoft.AspNetCore.Sockets/{EndpointDependencyInjectionExtensions.cs => EndPointDependencyInjectionExtensions.cs} (93%) create mode 100644 test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj diff --git a/SignalR.sln b/SignalR.sln index cb06185f7b..bdf0240d45 100644 --- a/SignalR.sln +++ b/SignalR.sln @@ -1,6 +1,6 @@ Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 15 -VisualStudioVersion = 15.0.26228.9 +VisualStudioVersion = 15.0.26411.1 MinimumVisualStudioVersion = 10.0.40219.1 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}" EndProject @@ -74,6 +74,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "client-ts", "client-ts", "{ EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.SignalR.Microbenchmarks", "test\Microsoft.AspNetCore.SignalR.Microbenchmarks\Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj", "{96771B3F-4D18-41A7-A75B-FF38E76AAC89}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.Common.Tests", "test\Microsoft.AspNetCore.SignalR.Common.Tests\Microsoft.AspNetCore.SignalR.Common.Tests.csproj", "{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -180,6 +182,10 @@ Global {96771B3F-4D18-41A7-A75B-FF38E76AAC89}.Debug|Any CPU.Build.0 = Debug|Any CPU {96771B3F-4D18-41A7-A75B-FF38E76AAC89}.Release|Any CPU.ActiveCfg = Release|Any CPU {96771B3F-4D18-41A7-A75B-FF38E76AAC89}.Release|Any CPU.Build.0 = Release|Any CPU + {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -211,5 +217,6 @@ Global {333526A4-633B-491A-AC45-CC62A0012D1C} = {3A76C5A2-79ED-49BC-8BDC-6A3A766FFA1B} {6CEC3DC2-5B01-45A8-8F0D-8531315DA90B} = {6A35B453-52EC-48AF-89CA-D4A69800F131} {96771B3F-4D18-41A7-A75B-FF38E76AAC89} = {6A35B453-52EC-48AF-89CA-D4A69800F131} + {75E342F6-5445-4E7E-9143-6D9AE62C2B1E} = {6A35B453-52EC-48AF-89CA-D4A69800F131} EndGlobalSection EndGlobal diff --git a/build/dependencies.props b/build/dependencies.props index 17b9f395ac..266b49aaff 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -1,6 +1,7 @@ - + 0.4.0-* + 4.4.0-* 2.0.0-* 0.1.0-* 4.3.0 diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts index c3eae1fb51..82b8d0a591 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts +++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts @@ -4,7 +4,7 @@ import { DataReceived, ConnectionClosed } from "../Microsoft.AspNetCore.SignalR. import { TransportType, ITransport } from "../Microsoft.AspNetCore.SignalR.Client.TS/Transports" describe("HubConnection", () => { - it("completes pending invocations when stopped", async (done) => { + it("completes pending invocations when stopped", async done => { let connection: IConnection = { start(transportType: TransportType | ITransport): Promise { return Promise.resolve(); @@ -27,18 +27,18 @@ describe("HubConnection", () => { let hubConnection = new HubConnection(connection); var invokePromise = hubConnection.invoke("testMethod"); hubConnection.stop(); - invokePromise - .then(() => { - fail(); - done(); - }) - .catch((error: Error) => { - expect(error.message).toBe("Invocation cancelled due to connection being closed."); - done(); - }); + + try { + await invokePromise; + fail(); + } + catch (e) { + expect(e.message).toBe("Invocation cancelled due to connection being closed."); + } + done(); }); - it("completes pending invocations when connection is lost", async (done) => { + it("completes pending invocations when connection is lost", async done => { let connection: IConnection = { start(transportType: TransportType | ITransport): Promise { return Promise.resolve(); @@ -60,17 +60,92 @@ describe("HubConnection", () => { let hubConnection = new HubConnection(connection); var invokePromise = hubConnection.invoke("testMethod"); - invokePromise - .then(() => { - fail(); - done(); - }) - .catch((error: Error) => { - expect(error.message).toBe("Connection lost"); - done(); - }); - // Typically this would be called by the transport connection.onClosed(new Error("Connection lost")); + + try { + await invokePromise; + fail(); + } + catch (e) { + expect(e.message).toBe("Connection lost"); + } + done(); + }); + + it("sends invocations as nonblocking", async done => { + let dataSent: string; + let connection: IConnection = { + start(transportType: TransportType): Promise { + return Promise.resolve(); + }, + + send(data: any): Promise { + dataSent = data; + return Promise.resolve(); + }, + + stop(): void { + if (this.onClosed) { + this.onClosed(); + } + }, + + onDataReceived: null, + onClosed: null + }; + + let hubConnection = new HubConnection(connection); + let invokePromise = hubConnection.invoke("testMethod"); + + expect(JSON.parse(dataSent).nonblocking).toBe(false); + + // will clean pending promises + connection.onClosed(); + + try { + await invokePromise; + fail(); // exception is expected because the call has not completed + } + catch (e) { + } + done(); + }); + + it("rejects streaming responses", async done => { + let connection: IConnection = { + start(transportType: TransportType): Promise { + return Promise.resolve(); + }, + + send(data: any): Promise { + return Promise.resolve(); + }, + + stop(): void { + if (this.onClosed) { + this.onClosed(); + } + }, + + onDataReceived: null, + onClosed: null + }; + + let hubConnection = new HubConnection(connection); + let invokePromise = hubConnection.invoke("testMethod"); + + connection.onDataReceived("{ \"type\": 2, \"invocationId\": \"0\", \"result\": null }"); + connection.onClosed(); + + try { + await invokePromise; + fail(); + } + catch (e) { + expect(e.message).toBe("Streaming is not supported."); + } + + done(); }); }); \ No newline at end of file diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts index 7f6c78a95b..4967e0078d 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts +++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts @@ -3,16 +3,31 @@ import { IConnection } from "./IConnection" import { Connection } from "./Connection" import { TransportType } from "./Transports" -interface InvocationDescriptor { - readonly Id: string; - readonly Method: string; - readonly Arguments: Array; + +const enum MessageType { + Invocation = 1, + Result, + Completion } -interface InvocationResultDescriptor { - readonly Id: string; - readonly Error: string; - readonly Result: any; +interface HubMessage { + readonly type: MessageType; + readonly invocationId: string; +} + +interface InvocationMessage extends HubMessage { + readonly target: string; + readonly arguments: Array; + readonly nonblocking?: boolean; +} + +interface ResultMessage extends HubMessage { + readonly result?: any; +} + +interface CompletionMessage extends HubMessage { + readonly error?: string; + readonly result?: any; } export { Connection } from "./Connection" @@ -20,7 +35,7 @@ export { TransportType } from "./Transports" export class HubConnection { private connection: IConnection; - private callbacks: Map void>; + private callbacks: Map void>; private methods: Map void>; private id: number; private connectionClosedCallback: ConnectionClosed; @@ -40,7 +55,7 @@ export class HubConnection { this.onConnectionClosed(error); } - this.callbacks = new Map void>(); + this.callbacks = new Map void>(); this.methods = new Map void>(); this.id = 0; } @@ -51,34 +66,49 @@ export class HubConnection { if (!data) { return; } - var descriptor = JSON.parse(data); - if (descriptor.Method === undefined) { - let invocationResult: InvocationResultDescriptor = descriptor; - let callback = this.callbacks.get(invocationResult.Id); - if (callback != null) { - callback(invocationResult); - this.callbacks.delete(invocationResult.Id); + + var message = JSON.parse(data); + switch (message.type) { + case MessageType.Invocation: + this.InvokeClientMethod(message); + break; + case MessageType.Result: + // TODO: Streaming (MessageType.Result) currently not supported - callback will throw + case MessageType.Completion: + let callback = this.callbacks.get(message.invocationId); + if (callback != null) { + callback(message); + this.callbacks.delete(message.invocationId); + } + break; + default: + console.log("Invalid message type: " + data); + break; + } + } + + private InvokeClientMethod(invocationMessage: InvocationMessage) { + let method = this.methods.get(invocationMessage.target); + if (method) { + method.apply(this, invocationMessage.arguments); + if (!invocationMessage.nonblocking) { + // TODO: send result back to the server? } } else { - let invocation: InvocationDescriptor = descriptor; - let method = this.methods[invocation.Method]; - if (method != null) { - // TODO: bind? args? - method.apply(this, invocation.Arguments); - } + console.log(`No client method with the name '${invocationMessage.target}' found.`); } } private onConnectionClosed(error: Error) { - let errorInvocationResult = { - Id: "-1", - Error: error ? error.message : "Invocation cancelled due to connection being closed.", - Result: null - } as InvocationResultDescriptor; + let errorCompletionMessage = { + type: MessageType.Completion, + invocationId: "-1", + error: error ? error.message : "Invocation cancelled due to connection being closed.", + }; this.callbacks.forEach(callback => { - callback(errorInvocationResult); + callback(errorCompletionMessage); }); this.callbacks.clear(); @@ -99,19 +129,27 @@ export class HubConnection { let id = this.id; this.id++; - let invocationDescriptor: InvocationDescriptor = { - "Id": id.toString(), - "Method": methodName, - "Arguments": args + let invocationDescriptor: InvocationMessage = { + type: MessageType.Invocation, + invocationId: id.toString(), + target: methodName, + arguments: args, + nonblocking: false }; let p = new Promise((resolve, reject) => { - this.callbacks.set(invocationDescriptor.Id, (invocationResult: InvocationResultDescriptor) => { - if (invocationResult.Error != null) { - reject(new Error(invocationResult.Error)); + this.callbacks.set(invocationDescriptor.invocationId, (invocationEvent: CompletionMessage | ResultMessage) => { + if (invocationEvent.type === MessageType.Completion) { + let completionMessage = invocationEvent; + if (completionMessage.error) { + reject(new Error(completionMessage.error)); + } + else { + resolve(completionMessage.result); + } } else { - resolve(invocationResult.Result); + reject(new Error("Streaming is not supported.")) } }); @@ -119,7 +157,7 @@ export class HubConnection { this.connection.send(JSON.stringify(invocationDescriptor)) .catch(e => { reject(e); - this.callbacks.delete(invocationDescriptor.Id); + this.callbacks.delete(invocationDescriptor.invocationId); }); }); @@ -127,7 +165,7 @@ export class HubConnection { } on(methodName: string, method: (...args: any[]) => void) { - this.methods[methodName] = method; + this.methods.set(methodName, method); } set onClosed(callback: ConnectionClosed) { diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs index 9a513eacda..771101f605 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using Microsoft.AspNetCore.Builder; @@ -24,7 +24,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server app.UseDeveloperExceptionPage(); } - app.UseStaticFiles(); + app.UseFileServer(); app.UseSockets(options => options.MapEndpoint("/echo")); app.UseSignalR(routes => { diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/default.html b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/default.html new file mode 100644 index 0000000000..d06a274aff --- /dev/null +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/default.html @@ -0,0 +1,13 @@ + + + + + SignalR Tests + + +

SignalR Tests

+ + + diff --git a/samples/ClientSample/HubSample.cs b/samples/ClientSample/HubSample.cs index 049e20551d..a277899d00 100644 --- a/samples/ClientSample/HubSample.cs +++ b/samples/ClientSample/HubSample.cs @@ -5,7 +5,6 @@ using System; using System.Linq; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR.Client; using Microsoft.Extensions.CommandLineUtils; using Microsoft.Extensions.Logging; @@ -33,7 +32,7 @@ namespace ClientSample var loggerFactory = new LoggerFactory(); Console.WriteLine("Connecting to {0}", baseUrl); - var connection = new HubConnection(new Uri(baseUrl), new JsonNetInvocationAdapter(), loggerFactory); + var connection = new HubConnection(new Uri(baseUrl), loggerFactory); try { await connection.StartAsync(); diff --git a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs index f7fa6e25b3..899a3fc625 100644 --- a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs +++ b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -22,6 +22,7 @@ namespace SocialWeather public void OnConnectedAsync(Connection connection) { + connection.Metadata[ConnectionMetadataNames.Format] = "json"; _connectionList.Add(connection); } @@ -34,11 +35,11 @@ namespace SocialWeather { foreach (var connection in _connectionList) { - var formatter = _formatterResolver.GetFormatter(connection.Metadata.Get("formatType")); + var formatter = _formatterResolver.GetFormatter(connection.Metadata.Get(ConnectionMetadataNames.Format)); var ms = new MemoryStream(); await formatter.WriteAsync(data, ms); - var context = (HttpContext)connection.Metadata[typeof(HttpContext)]; + var context = (HttpContext)connection.Metadata[ConnectionMetadataNames.HttpContext]; var format = string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase) ? MessageType.Binary diff --git a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs index 5eec621674..3f68fa8f7d 100644 --- a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs +++ b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs @@ -17,7 +17,7 @@ namespace SocketsSample.EndPoints { Connections.Add(connection); - await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})"); + await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata[ConnectionMetadataNames.Transport]})"); try { @@ -37,7 +37,7 @@ namespace SocketsSample.EndPoints { Connections.Remove(connection); - await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})"); + await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata[ConnectionMetadataNames.Transport]})"); } } diff --git a/samples/SocketsSample/LineInvocationAdapter.cs b/samples/SocketsSample/LineInvocationAdapter.cs deleted file mode 100644 index e2a17ed1a5..0000000000 --- a/samples/SocketsSample/LineInvocationAdapter.cs +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.IO; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.AspNetCore.SignalR; - -namespace SocketsSample -{ - public class LineInvocationAdapter : IInvocationAdapter - { - public async Task ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken) - { - var streamReader = new StreamReader(stream); - var line = await streamReader.ReadLineAsync(); - if (line == null) - { - return null; - } - - var values = line.Split(','); - - var type = values[0].Substring(0, 2); - var id = values[0].Substring(2); - - if (type.Equals("RI")) - { - var resultType = values[1].Substring(0, 1); - var result = values[1].Substring(1); - return new InvocationResultDescriptor() - { - Id = id, - Result = resultType.Equals("E") ? null : result, - Error = resultType.Equals("E") ? result : null, - }; - } - else - { - var method = values[1].Substring(1); - - return new InvocationDescriptor - { - Id = id, - Method = method, - Arguments = values.Skip(2).Zip(binder.GetParameterTypes(method), (v, t) => Convert.ChangeType(v, t)).ToArray() - }; - } - } - - public Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken) - { - var invocationDescriptor = message as InvocationDescriptor; - if (invocationDescriptor != null) - { - return WriteInvocationDescriptorAsync(invocationDescriptor, stream); - } - else - { - return WriteInvocationResultAsync((InvocationResultDescriptor)message, stream); - } - } - - private Task WriteInvocationDescriptorAsync(InvocationDescriptor invocationDescriptor, Stream stream) - { - var msg = $"CI{invocationDescriptor.Id},M{invocationDescriptor.Method},{string.Join(",", invocationDescriptor.Arguments.Select(a => a.ToString()))}\n"; - return WriteAsync(msg, stream); - } - - private Task WriteInvocationResultAsync(InvocationResultDescriptor resultDescriptor, Stream stream) - { - if (string.IsNullOrEmpty(resultDescriptor.Error)) - { - return WriteAsync($"RI{resultDescriptor.Id},E{resultDescriptor.Error}\n", stream); - } - else - { - return WriteAsync($"RI{resultDescriptor.Id},R{(resultDescriptor.Result != null ? resultDescriptor.Result.ToString() : string.Empty)}\n", stream); - } - } - - private async Task WriteAsync(string msg, Stream stream) - { - var writer = new StreamWriter(stream); - await writer.WriteAsync(msg); - await writer.FlushAsync(); - } - } -} diff --git a/samples/SocketsSample/Protobuf/ProtobufInvocationAdapter.cs b/samples/SocketsSample/Protobuf/ProtobufInvocationAdapter.cs deleted file mode 100644 index 9d5501e8a7..0000000000 --- a/samples/SocketsSample/Protobuf/ProtobufInvocationAdapter.cs +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using Google.Protobuf; -using Microsoft.AspNetCore.SignalR; -using Microsoft.Extensions.DependencyInjection; - -namespace SocketsSample.Protobuf -{ - public class ProtobufInvocationAdapter : IInvocationAdapter - { - private IServiceProvider _serviceProvider; - - public ProtobufInvocationAdapter(IServiceProvider serviceProvider) - { - _serviceProvider = serviceProvider; - } - - public Task ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken) - { - return Task.Run(() => CreateInvocationMessageInt(stream, binder)); - } - - public Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } - - private Task CreateInvocationMessageInt(Stream stream, IInvocationBinder binder) - { - var inputStream = new CodedInputStream(stream, leaveOpen: true); - var messageKind = new RpcMessageKind(); - inputStream.ReadMessage(messageKind); - if(messageKind.MessageKind == RpcMessageKind.Types.Kind.Invocation) - { - return CreateInvocationDescriptorInt(inputStream, binder); - } - else - { - return CreateInvocationResultDescriptorInt(inputStream, binder); - } - } - - private Task CreateInvocationResultDescriptorInt(CodedInputStream inputStream, IInvocationBinder binder) - { - throw new NotImplementedException("Not yet implemented for Protobuf"); - } - - private Task CreateInvocationDescriptorInt(CodedInputStream inputStream, IInvocationBinder binder) - { - var invocationHeader = new RpcInvocationHeader(); - inputStream.ReadMessage(invocationHeader); - var argumentTypes = binder.GetParameterTypes(invocationHeader.Name); - - var invocationDescriptor = new InvocationDescriptor(); - invocationDescriptor.Method = invocationHeader.Name; - invocationDescriptor.Id = invocationHeader.Id.ToString(); - invocationDescriptor.Arguments = new object[argumentTypes.Length]; - - var primitiveParser = PrimitiveValue.Parser; - - for (var i = 0; i < argumentTypes.Length; i++) - { - if (typeof(int) == argumentTypes[i]) - { - var value = new PrimitiveValue(); - inputStream.ReadMessage(value); - invocationDescriptor.Arguments[i] = value.Int32Value; - } - else if (typeof(string) == argumentTypes[i]) - { - var value = new PrimitiveValue(); - inputStream.ReadMessage(value); - invocationDescriptor.Arguments[i] = value.StringValue; - } - else - { - var serializer = _serviceProvider.GetRequiredService(); - invocationDescriptor.Arguments[i] = serializer.GetValue(inputStream, argumentTypes[i]); - } - } - - return Task.FromResult(invocationDescriptor); - } - - public async Task WriteInvocationResultAsync(InvocationResultDescriptor resultDescriptor, Stream stream) - { - var outputStream = new CodedOutputStream(stream, leaveOpen: true); - outputStream.WriteMessage(new RpcMessageKind() { MessageKind = RpcMessageKind.Types.Kind.Result }); - - var resultHeader = new RpcInvocationResultHeader - { - Id = int.Parse(resultDescriptor.Id), - HasResult = resultDescriptor.Result != null - }; - - if (resultDescriptor.Error != null) - { - resultHeader.Error = resultDescriptor.Error; - } - - outputStream.WriteMessage(resultHeader); - - if (string.IsNullOrEmpty(resultHeader.Error) && resultDescriptor.Result != null) - { - var result = resultDescriptor.Result; - - if (result.GetType() == typeof(int)) - { - outputStream.WriteMessage(new PrimitiveValue { Int32Value = (int)result }); - } - else if (result.GetType() == typeof(string)) - { - outputStream.WriteMessage(new PrimitiveValue { StringValue = (string)result }); - } - else - { - var serializer = _serviceProvider.GetRequiredService(); - var message = serializer.GetMessage(result); - outputStream.WriteMessage(message); - } - } - - outputStream.Flush(); - await stream.FlushAsync(); - } - - public async Task WriteInvocationDescriptorAsync(InvocationDescriptor invocationDescriptor, Stream stream) - { - var outputStream = new CodedOutputStream(stream, leaveOpen: true); - outputStream.WriteMessage(new RpcMessageKind() { MessageKind = RpcMessageKind.Types.Kind.Invocation }); - - var invocationHeader = new RpcInvocationHeader() - { - Id = 0, - Name = invocationDescriptor.Method, - NumArgs = invocationDescriptor.Arguments.Length - }; - - outputStream.WriteMessage(invocationHeader); - - foreach (var arg in invocationDescriptor.Arguments) - { - if (arg.GetType() == typeof(int)) - { - outputStream.WriteMessage(new PrimitiveValue { Int32Value = (int)arg }); - } - else if (arg.GetType() == typeof(string)) - { - outputStream.WriteMessage(new PrimitiveValue { StringValue = (string)arg }); - } - else - { - var serializer = _serviceProvider.GetRequiredService(); - var message = serializer.GetMessage(arg); - outputStream.WriteMessage(message); - } - } - - outputStream.Flush(); - await stream.FlushAsync(); - } - } -} diff --git a/samples/SocketsSample/Protobuf/RpcInvocation.cs b/samples/SocketsSample/Protobuf/RpcInvocation.cs deleted file mode 100644 index a20fc10e4f..0000000000 --- a/samples/SocketsSample/Protobuf/RpcInvocation.cs +++ /dev/null @@ -1,848 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: RpcInvocation.proto -#pragma warning disable 1591, 0612, 3021 -#region Designer generated code - -using pb = global::Google.Protobuf; -using pbc = global::Google.Protobuf.Collections; -using pbr = global::Google.Protobuf.Reflection; -using scg = global::System.Collections.Generic; -/// Holder for reflection information generated from RpcInvocation.proto -public static partial class RpcInvocationReflection { - - #region Descriptor - /// File descriptor for RpcInvocation.proto - public static pbr::FileDescriptor Descriptor { - get { return descriptor; } - } - private static pbr::FileDescriptor descriptor; - - static RpcInvocationReflection() { - byte[] descriptorData = global::System.Convert.FromBase64String( - string.Concat( - "ChNScGNJbnZvY2F0aW9uLnByb3RvIl8KDlJwY01lc3NhZ2VLaW5kEikKC01l", - "c3NhZ2VLaW5kGAEgASgOMhQuUnBjTWVzc2FnZUtpbmQuS2luZCIiCgRLaW5k", - "EgoKBlJlc3VsdBAAEg4KCkludm9jYXRpb24QASJAChNScGNJbnZvY2F0aW9u", - "SGVhZGVyEgwKBE5hbWUYASABKAkSCgoCSWQYAiABKAUSDwoHTnVtQXJncxgD", - "IAEoBSJJChlScGNJbnZvY2F0aW9uUmVzdWx0SGVhZGVyEgoKAklkGAEgASgF", - "EhEKCUhhc1Jlc3VsdBgCIAEoCBINCgVFcnJvchgDIAEoCSJHCg5QcmltaXRp", - "dmVWYWx1ZRIUCgpJbnQzMlZhbHVlGAEgASgFSAASFQoLU3RyaW5nVmFsdWUY", - "AiABKAlIAEIICgZvbmVvZl8iKgoNUGVyc29uTWVzc2FnZRIMCgROYW1lGAEg", - "ASgJEgsKA0FnZRgCIAEoA2IGcHJvdG8z")); - descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, - new pbr::FileDescriptor[] { }, - new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::RpcMessageKind), global::RpcMessageKind.Parser, new[]{ "MessageKind" }, null, new[]{ typeof(global::RpcMessageKind.Types.Kind) }, null), - new pbr::GeneratedClrTypeInfo(typeof(global::RpcInvocationHeader), global::RpcInvocationHeader.Parser, new[]{ "Name", "Id", "NumArgs" }, null, null, null), - new pbr::GeneratedClrTypeInfo(typeof(global::RpcInvocationResultHeader), global::RpcInvocationResultHeader.Parser, new[]{ "Id", "HasResult", "Error" }, null, null, null), - new pbr::GeneratedClrTypeInfo(typeof(global::PrimitiveValue), global::PrimitiveValue.Parser, new[]{ "Int32Value", "StringValue" }, new[]{ "Oneof" }, null, null), - new pbr::GeneratedClrTypeInfo(typeof(global::PersonMessage), global::PersonMessage.Parser, new[]{ "Name", "Age" }, null, null, null) - })); - } - #endregion - -} -#region Messages -public sealed partial class RpcMessageKind : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RpcMessageKind()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::RpcInvocationReflection.Descriptor.MessageTypes[0]; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public RpcMessageKind() { - OnConstruction(); - } - - partial void OnConstruction(); - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public RpcMessageKind(RpcMessageKind other) : this() { - messageKind_ = other.messageKind_; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public RpcMessageKind Clone() { - return new RpcMessageKind(this); - } - - /// Field number for the "MessageKind" field. - public const int MessageKindFieldNumber = 1; - private global::RpcMessageKind.Types.Kind messageKind_ = 0; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::RpcMessageKind.Types.Kind MessageKind { - get { return messageKind_; } - set { - messageKind_ = value; - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as RpcMessageKind); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(RpcMessageKind other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (MessageKind != other.MessageKind) return false; - return true; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (MessageKind != 0) hash ^= MessageKind.GetHashCode(); - return hash; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (MessageKind != 0) { - output.WriteRawTag(8); - output.WriteEnum((int) MessageKind); - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (MessageKind != 0) { - size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) MessageKind); - } - return size; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(RpcMessageKind other) { - if (other == null) { - return; - } - if (other.MessageKind != 0) { - MessageKind = other.MessageKind; - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 8: { - messageKind_ = (global::RpcMessageKind.Types.Kind) input.ReadEnum(); - break; - } - } - } - } - - #region Nested types - /// Container for nested types declared in the RpcMessageKind message type. - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static partial class Types { - public enum Kind { - [pbr::OriginalName("Result")] Result = 0, - [pbr::OriginalName("Invocation")] Invocation = 1, - } - - } - #endregion - -} - -public sealed partial class RpcInvocationHeader : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RpcInvocationHeader()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::RpcInvocationReflection.Descriptor.MessageTypes[1]; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public RpcInvocationHeader() { - OnConstruction(); - } - - partial void OnConstruction(); - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public RpcInvocationHeader(RpcInvocationHeader other) : this() { - name_ = other.name_; - id_ = other.id_; - numArgs_ = other.numArgs_; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public RpcInvocationHeader Clone() { - return new RpcInvocationHeader(this); - } - - /// Field number for the "Name" field. - public const int NameFieldNumber = 1; - private string name_ = ""; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Name { - get { return name_; } - set { - name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } - - /// Field number for the "Id" field. - public const int IdFieldNumber = 2; - private int id_; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int Id { - get { return id_; } - set { - id_ = value; - } - } - - /// Field number for the "NumArgs" field. - public const int NumArgsFieldNumber = 3; - private int numArgs_; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int NumArgs { - get { return numArgs_; } - set { - numArgs_ = value; - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as RpcInvocationHeader); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(RpcInvocationHeader other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (Name != other.Name) return false; - if (Id != other.Id) return false; - if (NumArgs != other.NumArgs) return false; - return true; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (Name.Length != 0) hash ^= Name.GetHashCode(); - if (Id != 0) hash ^= Id.GetHashCode(); - if (NumArgs != 0) hash ^= NumArgs.GetHashCode(); - return hash; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (Name.Length != 0) { - output.WriteRawTag(10); - output.WriteString(Name); - } - if (Id != 0) { - output.WriteRawTag(16); - output.WriteInt32(Id); - } - if (NumArgs != 0) { - output.WriteRawTag(24); - output.WriteInt32(NumArgs); - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (Name.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); - } - if (Id != 0) { - size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id); - } - if (NumArgs != 0) { - size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumArgs); - } - return size; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(RpcInvocationHeader other) { - if (other == null) { - return; - } - if (other.Name.Length != 0) { - Name = other.Name; - } - if (other.Id != 0) { - Id = other.Id; - } - if (other.NumArgs != 0) { - NumArgs = other.NumArgs; - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 10: { - Name = input.ReadString(); - break; - } - case 16: { - Id = input.ReadInt32(); - break; - } - case 24: { - NumArgs = input.ReadInt32(); - break; - } - } - } - } - -} - -public sealed partial class RpcInvocationResultHeader : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RpcInvocationResultHeader()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::RpcInvocationReflection.Descriptor.MessageTypes[2]; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public RpcInvocationResultHeader() { - OnConstruction(); - } - - partial void OnConstruction(); - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public RpcInvocationResultHeader(RpcInvocationResultHeader other) : this() { - id_ = other.id_; - hasResult_ = other.hasResult_; - error_ = other.error_; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public RpcInvocationResultHeader Clone() { - return new RpcInvocationResultHeader(this); - } - - /// Field number for the "Id" field. - public const int IdFieldNumber = 1; - private int id_; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int Id { - get { return id_; } - set { - id_ = value; - } - } - - /// Field number for the "HasResult" field. - public const int HasResultFieldNumber = 2; - private bool hasResult_; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool HasResult { - get { return hasResult_; } - set { - hasResult_ = value; - } - } - - /// Field number for the "Error" field. - public const int ErrorFieldNumber = 3; - private string error_ = ""; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Error { - get { return error_; } - set { - error_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as RpcInvocationResultHeader); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(RpcInvocationResultHeader other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (Id != other.Id) return false; - if (HasResult != other.HasResult) return false; - if (Error != other.Error) return false; - return true; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (Id != 0) hash ^= Id.GetHashCode(); - if (HasResult != false) hash ^= HasResult.GetHashCode(); - if (Error.Length != 0) hash ^= Error.GetHashCode(); - return hash; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (Id != 0) { - output.WriteRawTag(8); - output.WriteInt32(Id); - } - if (HasResult != false) { - output.WriteRawTag(16); - output.WriteBool(HasResult); - } - if (Error.Length != 0) { - output.WriteRawTag(26); - output.WriteString(Error); - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (Id != 0) { - size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id); - } - if (HasResult != false) { - size += 1 + 1; - } - if (Error.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Error); - } - return size; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(RpcInvocationResultHeader other) { - if (other == null) { - return; - } - if (other.Id != 0) { - Id = other.Id; - } - if (other.HasResult != false) { - HasResult = other.HasResult; - } - if (other.Error.Length != 0) { - Error = other.Error; - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 8: { - Id = input.ReadInt32(); - break; - } - case 16: { - HasResult = input.ReadBool(); - break; - } - case 26: { - Error = input.ReadString(); - break; - } - } - } - } - -} - -public sealed partial class PrimitiveValue : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new PrimitiveValue()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::RpcInvocationReflection.Descriptor.MessageTypes[3]; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public PrimitiveValue() { - OnConstruction(); - } - - partial void OnConstruction(); - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public PrimitiveValue(PrimitiveValue other) : this() { - switch (other.OneofCase) { - case OneofOneofCase.Int32Value: - Int32Value = other.Int32Value; - break; - case OneofOneofCase.StringValue: - StringValue = other.StringValue; - break; - } - - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public PrimitiveValue Clone() { - return new PrimitiveValue(this); - } - - /// Field number for the "Int32Value" field. - public const int Int32ValueFieldNumber = 1; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int Int32Value { - get { return oneofCase_ == OneofOneofCase.Int32Value ? (int) oneof_ : 0; } - set { - oneof_ = value; - oneofCase_ = OneofOneofCase.Int32Value; - } - } - - /// Field number for the "StringValue" field. - public const int StringValueFieldNumber = 2; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string StringValue { - get { return oneofCase_ == OneofOneofCase.StringValue ? (string) oneof_ : ""; } - set { - oneof_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - oneofCase_ = OneofOneofCase.StringValue; - } - } - - private object oneof_; - /// Enum of possible cases for the "oneof_" oneof. - public enum OneofOneofCase { - None = 0, - Int32Value = 1, - StringValue = 2, - } - private OneofOneofCase oneofCase_ = OneofOneofCase.None; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public OneofOneofCase OneofCase { - get { return oneofCase_; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void ClearOneof() { - oneofCase_ = OneofOneofCase.None; - oneof_ = null; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as PrimitiveValue); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(PrimitiveValue other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (Int32Value != other.Int32Value) return false; - if (StringValue != other.StringValue) return false; - if (OneofCase != other.OneofCase) return false; - return true; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (oneofCase_ == OneofOneofCase.Int32Value) hash ^= Int32Value.GetHashCode(); - if (oneofCase_ == OneofOneofCase.StringValue) hash ^= StringValue.GetHashCode(); - hash ^= (int) oneofCase_; - return hash; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (oneofCase_ == OneofOneofCase.Int32Value) { - output.WriteRawTag(8); - output.WriteInt32(Int32Value); - } - if (oneofCase_ == OneofOneofCase.StringValue) { - output.WriteRawTag(18); - output.WriteString(StringValue); - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (oneofCase_ == OneofOneofCase.Int32Value) { - size += 1 + pb::CodedOutputStream.ComputeInt32Size(Int32Value); - } - if (oneofCase_ == OneofOneofCase.StringValue) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(StringValue); - } - return size; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(PrimitiveValue other) { - if (other == null) { - return; - } - switch (other.OneofCase) { - case OneofOneofCase.Int32Value: - Int32Value = other.Int32Value; - break; - case OneofOneofCase.StringValue: - StringValue = other.StringValue; - break; - } - - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 8: { - Int32Value = input.ReadInt32(); - break; - } - case 18: { - StringValue = input.ReadString(); - break; - } - } - } - } - -} - -public sealed partial class PersonMessage : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new PersonMessage()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::RpcInvocationReflection.Descriptor.MessageTypes[4]; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public PersonMessage() { - OnConstruction(); - } - - partial void OnConstruction(); - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public PersonMessage(PersonMessage other) : this() { - name_ = other.name_; - age_ = other.age_; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public PersonMessage Clone() { - return new PersonMessage(this); - } - - /// Field number for the "Name" field. - public const int NameFieldNumber = 1; - private string name_ = ""; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Name { - get { return name_; } - set { - name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } - - /// Field number for the "Age" field. - public const int AgeFieldNumber = 2; - private long age_; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public long Age { - get { return age_; } - set { - age_ = value; - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as PersonMessage); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(PersonMessage other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (Name != other.Name) return false; - if (Age != other.Age) return false; - return true; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (Name.Length != 0) hash ^= Name.GetHashCode(); - if (Age != 0L) hash ^= Age.GetHashCode(); - return hash; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (Name.Length != 0) { - output.WriteRawTag(10); - output.WriteString(Name); - } - if (Age != 0L) { - output.WriteRawTag(16); - output.WriteInt64(Age); - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (Name.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); - } - if (Age != 0L) { - size += 1 + pb::CodedOutputStream.ComputeInt64Size(Age); - } - return size; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(PersonMessage other) { - if (other == null) { - return; - } - if (other.Name.Length != 0) { - Name = other.Name; - } - if (other.Age != 0L) { - Age = other.Age; - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 10: { - Name = input.ReadString(); - break; - } - case 16: { - Age = input.ReadInt64(); - break; - } - } - } - } - -} - -#endregion - - -#endregion Designer generated code diff --git a/samples/SocketsSample/Protobuf/RpcInvocation.proto b/samples/SocketsSample/Protobuf/RpcInvocation.proto deleted file mode 100644 index f63ad75a05..0000000000 --- a/samples/SocketsSample/Protobuf/RpcInvocation.proto +++ /dev/null @@ -1,30 +0,0 @@ -syntax = "proto3"; - -message RpcMessageKind { - enum Kind { Result = 0; Invocation = 1; } - Kind MessageKind = 1; -} - -message RpcInvocationHeader { - string Name = 1; - int32 Id = 2; - int32 NumArgs = 3; -} - -message RpcInvocationResultHeader { - int32 Id = 1; - bool HasResult = 2; - string Error = 3; -} - -message PrimitiveValue { - oneof oneof_ { - int32 Int32Value = 1; - string StringValue = 2; - } -} - -message PersonMessage { - string Name = 1; - int64 Age = 2; -} \ No newline at end of file diff --git a/samples/SocketsSample/ProtobufSerializer.cs b/samples/SocketsSample/ProtobufSerializer.cs deleted file mode 100644 index 3b1aaec672..0000000000 --- a/samples/SocketsSample/ProtobufSerializer.cs +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using Google.Protobuf; -using SocketsSample.Hubs; - -namespace SocketsSample -{ - public class ProtobufSerializer - { - public object GetValue(CodedInputStream inputStream, Type type) - { - if (type == typeof(Person)) - { - var value = new PersonMessage(); - inputStream.ReadMessage(value); - - return new Person { Name = value.Name, Age = value.Age }; - } - - throw new InvalidOperationException("(Deserialize) Unknown type."); - } - - public IMessage GetMessage(object value) - { - Person person = value as Person; - if (person != null) - { - return new PersonMessage { Name = person.Name, Age = person.Age }; - } - - throw new InvalidOperationException("(Serialize) Unknown type."); - } - } -} diff --git a/samples/SocketsSample/Startup.cs b/samples/SocketsSample/Startup.cs index b8f5f3feb4..93b0232efc 100644 --- a/samples/SocketsSample/Startup.cs +++ b/samples/SocketsSample/Startup.cs @@ -3,12 +3,10 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using SocketsSample.EndPoints; using SocketsSample.Hubs; -using SocketsSample.Protobuf; namespace SocketsSample { @@ -18,21 +16,12 @@ namespace SocketsSample // For more information on how to configure your application, visit http://go.microsoft.com/fwlink/?LinkID=398940 public void ConfigureServices(IServiceCollection services) { - services.AddSingleton(); - services.AddSingleton(); - services.AddSockets(); - services.AddSignalR(options => - { - options.RegisterInvocationAdapter("protobuf"); - options.RegisterInvocationAdapter("line"); - }); + services.AddSignalR(); // .AddRedis(); services.AddEndPoint(); - - services.AddSingleton(); } // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/IOutputExtensions.cs b/src/Common/IOutputExtensions.cs similarity index 100% rename from src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/IOutputExtensions.cs rename to src/Common/IOutputExtensions.cs diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs index c1b946e9e1..fa7bad1615 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs @@ -1,19 +1,20 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Diagnostics; -using System.IO; using System.Linq; using System.Net.Http; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Newtonsoft.Json; namespace Microsoft.AspNetCore.SignalR.Client { @@ -22,17 +23,14 @@ namespace Microsoft.AspNetCore.SignalR.Client private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; private readonly IConnection _connection; - private readonly IInvocationAdapter _adapter; + private readonly IHubProtocol _protocol; private readonly HubBinder _binder; private HttpClient _httpClient; - private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource(); - - // We need to ensure pending calls added after a connection failure don't hang. Right now the easiest thing to do is lock. private readonly object _pendingCallsLock = new object(); + private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource(); private readonly Dictionary _pendingCalls = new Dictionary(); - private readonly ConcurrentDictionary _handlers = new ConcurrentDictionary(); private int _nextId = 0; @@ -50,27 +48,33 @@ namespace Microsoft.AspNetCore.SignalR.Client } public HubConnection(Uri url) - : this(new Connection(url), new JsonNetInvocationAdapter(), null) + : this(new Connection(url), new JsonHubProtocol(new JsonSerializer()), null) { } public HubConnection(Uri url, ILoggerFactory loggerFactory) - : this(new Connection(url), new JsonNetInvocationAdapter(), loggerFactory) + : this(new Connection(url), new JsonHubProtocol(new JsonSerializer()), loggerFactory) { } - public HubConnection(Uri url, IInvocationAdapter adapter, ILoggerFactory loggerFactory) - : this(new Connection(url, loggerFactory), adapter, loggerFactory) + // These are only really needed for tests now... + public HubConnection(IConnection connection, ILoggerFactory loggerFactory) + : this(connection, new JsonHubProtocol(new JsonSerializer()), loggerFactory) { } - public HubConnection(IConnection connection, IInvocationAdapter adapter, ILoggerFactory loggerFactory) + public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFactory loggerFactory) { if (connection == null) { throw new ArgumentNullException(nameof(connection)); } + if (protocol == null) + { + throw new ArgumentNullException(nameof(protocol)); + } + _connection = connection; _binder = new HubBinder(this); - _adapter = adapter; + _protocol = protocol; _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; _logger = _loggerFactory.CreateLogger(); _connection.Received += OnDataReceived; @@ -100,6 +104,7 @@ namespace Microsoft.AspNetCore.SignalR.Client public async Task DisposeAsync() { await _connection.DisposeAsync(); + _httpClient?.Dispose(); } @@ -118,81 +123,80 @@ namespace Microsoft.AspNetCore.SignalR.Client public Task Invoke(string methodName, Type returnType, params object[] args) => Invoke(methodName, returnType, CancellationToken.None, args); public async Task Invoke(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args) { - _logger.LogTrace("Preparing invocation of '{0}', with return type '{1}' and {2} args", methodName, returnType.AssemblyQualifiedName, args.Length); + ThrowIfConnectionTerminated(); + _logger.LogTrace("Preparing invocation of '{target}', with return type '{returnType}' and {argumentCount} args", methodName, returnType.AssemblyQualifiedName, args.Length); - // Create an invocation descriptor. - var descriptor = new InvocationDescriptor - { - Id = GetNextId(), - Method = methodName, - Arguments = args - }; + // Create an invocation descriptor. Client invocations are always blocking + var invocationMessage = new InvocationMessage(GetNextId(), nonBlocking: false, target: methodName, arguments: args); // I just want an excuse to use 'irq' as a variable name... - _logger.LogDebug("Registering Invocation ID '{0}' for tracking", descriptor.Id); - var irq = new InvocationRequest(cancellationToken, returnType); + _logger.LogDebug("Registering Invocation ID '{invocationId}' for tracking", invocationMessage.InvocationId); + var irq = new InvocationRequest(cancellationToken, returnType, invocationMessage.InvocationId, _loggerFactory); - lock (_pendingCallsLock) - { - if (_connectionActive.IsCancellationRequested) - { - throw new InvalidOperationException("Connection has been terminated."); - } - _pendingCalls.Add(descriptor.Id, irq); - } + AddInvocation(irq); - // Trace the invocation, but only if that logging level is enabled (because building the args list is a bit slow) + // Trace the full invocation, but only if that logging level is enabled (because building the args list is a bit slow) if (_logger.IsEnabled(LogLevel.Trace)) { var argsList = string.Join(", ", args.Select(a => a.GetType().FullName)); - _logger.LogTrace("Invocation #{0}: {1} {2}({3})", descriptor.Id, returnType.FullName, methodName, argsList); + _logger.LogTrace("Issuing Invocation '{invocationId}': {returnType} {methodName}({args})", invocationMessage.InvocationId, returnType.FullName, methodName, argsList); } try { - var ms = new MemoryStream(); - await _adapter.WriteMessageAsync(descriptor, ms, cancellationToken); + var payload = await _protocol.WriteToArrayAsync(invocationMessage); - _logger.LogInformation("Sending Invocation #{0}", descriptor.Id); + _logger.LogInformation("Sending Invocation '{invocationId}'", invocationMessage.InvocationId); - // TODO: Format.Text - who, where and when decides about the format of outgoing messages - await _connection.SendAsync(ms.ToArray(), MessageType.Text, cancellationToken); - _logger.LogInformation("Sending Invocation #{0} complete", descriptor.Id); + await _connection.SendAsync(payload, _protocol.MessageType, cancellationToken); + _logger.LogInformation("Sending Invocation '{invocationId}' complete", invocationMessage.InvocationId); } catch (Exception ex) { - _logger.LogError(0, ex, "Sending Invocation #{0} failed", descriptor.Id); - irq.Completion.TrySetException(ex); - lock (_pendingCallsLock) - { - _pendingCalls.Remove(descriptor.Id); - } + _logger.LogError(0, ex, "Sending Invocation '{invocationId}' failed", invocationMessage.InvocationId); + irq.Fail(ex); + TryRemoveInvocation(invocationMessage.InvocationId, out _); } // Return the completion task. It will be completed by ReceiveMessages when the response is received. - return await irq.Completion.Task; + return await irq.Completion; } - private async void OnDataReceived(byte[] data, MessageType messageType) + private void OnDataReceived(byte[] data, MessageType messageType) { - var message - = await _adapter.ReadMessageAsync(new MemoryStream(data), _binder, _connectionActive.Token); + var message = _protocol.ParseMessage(data, _binder); + InvocationRequest irq; switch (message) { - case InvocationDescriptor invocationDescriptor: - DispatchInvocation(invocationDescriptor, _connectionActive.Token); - break; - case InvocationResultDescriptor invocationResultDescriptor: - InvocationRequest irq; - lock (_pendingCallsLock) + case InvocationMessage invocation: + if (_logger.IsEnabled(LogLevel.Trace)) { - _connectionActive.Token.ThrowIfCancellationRequested(); - irq = _pendingCalls[invocationResultDescriptor.Id]; - _pendingCalls.Remove(invocationResultDescriptor.Id); + var argsList = string.Join(", ", invocation.Arguments.Select(a => a.GetType().FullName)); + _logger.LogTrace("Received Invocation '{invocationId}': {methodName}({args})", invocation.InvocationId, invocation.Target, argsList); } - DispatchInvocationResult(invocationResultDescriptor, irq, _connectionActive.Token); + DispatchInvocation(invocation, _connectionActive.Token); break; + case CompletionMessage completion: + if (!TryRemoveInvocation(completion.InvocationId, out irq)) + { + _logger.LogWarning("Dropped unsolicited Completion message for invocation '{invocationId}'", completion.InvocationId); + return; + } + DispatchInvocationCompletion(completion, irq); + irq.Dispose(); + break; + case StreamItemMessage streamItem: + // Complete the invocation with an error, we don't support streaming (yet) + if (!TryRemoveInvocation(streamItem.InvocationId, out irq)) + { + _logger.LogWarning("Dropped unsolicited Stream Item message for invocation '{invocationId}'", streamItem.InvocationId); + return; + } + irq.Fail(new NotSupportedException("Streaming method results are not supported")); + break; + default: + throw new InvalidOperationException($"Unknown message type: {message.GetType().FullName}"); } } @@ -201,74 +205,118 @@ namespace Microsoft.AspNetCore.SignalR.Client _logger.LogTrace("Shutting down connection"); if (ex != null) { - _logger.LogError("Connection is shutting down due to an error: {0}", ex); + _logger.LogError(ex, "Connection is shutting down due to an error"); } lock (_pendingCallsLock) { + // We cancel inside the lock to make sure everyone who was part-way through registering an invocation + // completes. This also ensures that nobody will add things to _pendingCalls after we leave this block + // because everything that adds to _pendingCalls checks _connectionActive first (inside the _pendingCallsLock) _connectionActive.Cancel(); - foreach (var call in _pendingCalls.Values) + + foreach (var outstandingCall in _pendingCalls.Values) { + _logger.LogTrace("Removing pending call {invocationId}", outstandingCall.InvocationId); if (ex != null) { - call.Completion.TrySetException(ex); - } - else - { - call.Completion.TrySetCanceled(); + outstandingCall.Fail(ex); } + outstandingCall.Dispose(); } _pendingCalls.Clear(); } } - private void DispatchInvocation(InvocationDescriptor invocationDescriptor, CancellationToken cancellationToken) + private void DispatchInvocation(InvocationMessage invocation, CancellationToken cancellationToken) { // Find the handler - if (!_handlers.TryGetValue(invocationDescriptor.Method, out InvocationHandler handler)) + if (!_handlers.TryGetValue(invocation.Target, out InvocationHandler handler)) { - _logger.LogWarning("Failed to find handler for '{0}' method", invocationDescriptor.Method); + _logger.LogWarning("Failed to find handler for '{target}' method", invocation.Target); return; } // TODO: Return values // TODO: Dispatch to a sync context to ensure we aren't blocking this loop. - handler.Handler(invocationDescriptor.Arguments); + handler.Handler(invocation.Arguments); } - private void DispatchInvocationResult(InvocationResultDescriptor result, InvocationRequest irq, CancellationToken cancellationToken) + private void DispatchInvocationCompletion(CompletionMessage completion, InvocationRequest irq) { - _logger.LogInformation("Received Result for Invocation #{0}", result.Id); + _logger.LogTrace("Received Completion for Invocation #{invocationId}", completion.InvocationId); - if (cancellationToken.IsCancellationRequested) + if (irq.CancellationToken.IsCancellationRequested) { - return; + _logger.LogTrace("Cancelling dispatch of Completion message for Invocation {invocationId}. The invocation was cancelled.", irq.InvocationId); } - - Debug.Assert(irq.Completion != null, "Didn't properly capture InvocationRequest in callback for ReadInvocationResultDescriptorAsync"); - - // If the invocation hasn't been cancelled, dispatch the result - if (!irq.CancellationToken.IsCancellationRequested) + else { - irq.Registration.Dispose(); - - // Complete the request based on the result - // TODO: the TrySetXYZ methods will cause continuations attached to the Task to run, so we should dispatch to a sync context or thread pool. - if (!string.IsNullOrEmpty(result.Error)) + if (!string.IsNullOrEmpty(completion.Error)) { - _logger.LogInformation("Completing Invocation #{0} with error: {1}", result.Id, result.Error); - irq.Completion.TrySetException(new Exception(result.Error)); + irq.Fail(new HubException(completion.Error)); } else { - _logger.LogInformation("Completing Invocation #{0} with result of type: {1}", result.Id, result.Result?.GetType()?.FullName ?? "<>"); - irq.Completion.TrySetResult(result.Result); + irq.Complete(completion.Result); } } } + private void ThrowIfConnectionTerminated() + { + if (_connectionActive.Token.IsCancellationRequested) + { + _logger.LogError("Invoke was called after the connection was terminated"); + throw new InvalidOperationException("Connection has been terminated."); + } + } + private string GetNextId() => Interlocked.Increment(ref _nextId).ToString(); + private void AddInvocation(InvocationRequest irq) + { + lock (_pendingCallsLock) + { + ThrowIfConnectionTerminated(); + if (_pendingCalls.ContainsKey(irq.InvocationId)) + { + _logger.LogCritical("Invocation ID '{invocationId}' is already in use.", irq.InvocationId); + throw new InvalidOperationException($"Invocation ID '{irq.InvocationId}' is already in use."); + } + else + { + _pendingCalls.Add(irq.InvocationId, irq); + } + } + } + + private bool TryGetInvocation(string invocationId, out InvocationRequest irq) + { + lock (_pendingCallsLock) + { + ThrowIfConnectionTerminated(); + return _pendingCalls.TryGetValue(invocationId, out irq); + } + } + + private bool TryRemoveInvocation(string invocationId, out InvocationRequest irq) + { + lock (_pendingCallsLock) + { + ThrowIfConnectionTerminated(); + if (_pendingCalls.TryGetValue(invocationId, out irq)) + { + _pendingCalls.Remove(invocationId); + return true; + } + else + { + return false; + } + } + } + private class HubBinder : IInvocationBinder { private HubConnection _connection; @@ -282,7 +330,7 @@ namespace Microsoft.AspNetCore.SignalR.Client { if (!_connection._pendingCalls.TryGetValue(invocationId, out InvocationRequest irq)) { - _connection._logger.LogError("Unsolicited response received for invocation '{0}'", invocationId); + _connection._logger.LogError("Unsolicited response received for invocation '{invocationId}'", invocationId); return null; } return irq.ResultType; @@ -292,7 +340,7 @@ namespace Microsoft.AspNetCore.SignalR.Client { if (!_connection._handlers.TryGetValue(methodName, out InvocationHandler handler)) { - _connection._logger.LogWarning("Failed to find handler for '{0}' method", methodName); + _connection._logger.LogWarning("Failed to find handler for '{target}' method", methodName); return Type.EmptyTypes; } return handler.ParameterTypes; @@ -311,20 +359,51 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - private struct InvocationRequest + private class InvocationRequest : IDisposable { + private readonly TaskCompletionSource _completionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly CancellationTokenRegistration _cancellationTokenRegistration; + private readonly ILogger _logger; + public Type ResultType { get; } public CancellationToken CancellationToken { get; } - public CancellationTokenRegistration Registration { get; } - public TaskCompletionSource Completion { get; } + public string InvocationId { get; } - public InvocationRequest(CancellationToken cancellationToken, Type resultType) + public Task Completion => _completionSource.Task; + + + public InvocationRequest(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory) { - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - Completion = tcs; + _logger = loggerFactory.CreateLogger(); + _cancellationTokenRegistration = cancellationToken.Register(() => _completionSource.TrySetCanceled()); + + InvocationId = invocationId; CancellationToken = cancellationToken; - Registration = cancellationToken.Register(() => tcs.TrySetCanceled()); ResultType = resultType; + + _logger.LogTrace("Invocation {invocationId} created", InvocationId); + } + + public void Fail(Exception exception) + { + _logger.LogTrace("Invocation {invocationId} marked as failed", InvocationId); + _completionSource.TrySetException(exception); + } + + public void Complete(object result) + { + _logger.LogTrace("Invocation {invocationId} marked as completed", InvocationId); + _completionSource.TrySetResult(result); + } + + public void Dispose() + { + _logger.LogTrace("Invocation {invocationId} disposed", InvocationId); + + // Just in case it hasn't already been completed + _completionSource.TrySetCanceled(); + + _cancellationTokenRegistration.Dispose(); } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubException.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubException.cs new file mode 100644 index 0000000000..b2a9667ea2 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubException.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.SignalR.Client +{ + [Serializable] + public class HubException : Exception + { + public HubException() + { + } + + public HubException(string message) : base(message) + { + } + + public HubException(string message, Exception innerException) : base(message, innerException) + { + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/IInvocationAdapter.cs b/src/Microsoft.AspNetCore.SignalR.Common/IInvocationAdapter.cs deleted file mode 100644 index b92a7864ee..0000000000 --- a/src/Microsoft.AspNetCore.SignalR.Common/IInvocationAdapter.cs +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.SignalR -{ - public interface IInvocationAdapter - { - Task ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken); - - Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken); - } -} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/IInvocationBinder.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/IInvocationBinder.cs similarity index 87% rename from src/Microsoft.AspNetCore.SignalR.Common/IInvocationBinder.cs rename to src/Microsoft.AspNetCore.SignalR.Common/Internal/IInvocationBinder.cs index 7b56fc36eb..d1fa6057b1 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/IInvocationBinder.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/IInvocationBinder.cs @@ -3,7 +3,7 @@ using System; -namespace Microsoft.AspNetCore.SignalR +namespace Microsoft.AspNetCore.SignalR.Internal { public interface IInvocationBinder { diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CompletionMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CompletionMessage.cs new file mode 100644 index 0000000000..44b7206121 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CompletionMessage.cs @@ -0,0 +1,38 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public class CompletionMessage : HubMessage + { + public string Error { get; } + public object Result { get; } + public bool HasResult { get; } + + public CompletionMessage(string invocationId, string error, object result, bool hasResult) : base(invocationId) + { + if (error != null && result != null) + { + throw new ArgumentException($"Expected either '{nameof(error)}' or '{nameof(result)}' to be provided, but not both"); + } + Error = error; + Result = result; + HasResult = hasResult; + } + + public override string ToString() + { + var errorStr = Error == null ? "<>" : $"\"{Error}\""; + var resultField = HasResult ? $", {nameof(Result)}: {Result ?? "<>"}" : string.Empty; + return $"Completion {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Error)}: {errorStr}{resultField} }}"; + } + + // Static factory methods. Don't want to use constructor overloading because it will break down + // if you need to send a payload statically-typed as a string. And because a static factory is clearer here + public static CompletionMessage WithError(string invocationId, string error) => new CompletionMessage(invocationId, error, result: null, hasResult: false); + + public static CompletionMessage WithResult(string invocationId, object payload) => new CompletionMessage(invocationId, error: null, result: payload, hasResult: true); + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs new file mode 100644 index 0000000000..c5b8bae125 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public abstract class HubMessage + { + public string InvocationId { get; } + + protected HubMessage(string invocationId) + { + InvocationId = invocationId; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolWriteMessageExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolWriteMessageExtensions.cs new file mode 100644 index 0000000000..b53f2909ee --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolWriteMessageExtensions.cs @@ -0,0 +1,37 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.IO.Pipelines; +using System.IO.Pipelines.Text.Primitives; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public static class HubProtocolWriteMessageExtensions + { + public static async ValueTask WriteToArrayAsync(this IHubProtocol protocol, HubMessage message) + { + using (var memoryStream = new MemoryStream()) + { + var pipe = memoryStream.AsPipelineWriter(); + + // See https://github.com/dotnet/corefxlab/issues/1460, the TextEncoder is unimportant but required. + var output = new PipelineTextOutput(pipe, TextEncoder.Utf8); + + // Encode the message + if (!protocol.TryWriteMessage(message, output)) + { + throw new InvalidOperationException("Failed to write message to the output stream"); + } + + await output.FlushAsync(); + + // Create a message + return memoryStream.ToArray(); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs new file mode 100644 index 0000000000..eea83358c1 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using Microsoft.AspNetCore.Sockets; + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public interface IHubProtocol + { + MessageType MessageType { get; } + + HubMessage ParseMessage(ReadOnlySpan input, IInvocationBinder binder); + + bool TryWriteMessage(HubMessage message, IOutput output); + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/InvocationMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/InvocationMessage.cs new file mode 100644 index 0000000000..4f8a4c738f --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/InvocationMessage.cs @@ -0,0 +1,44 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public class InvocationMessage : HubMessage + { + public string Target { get; } + + public object[] Arguments { get; } + + public bool NonBlocking { get; } + + public InvocationMessage(string invocationId, bool nonBlocking, string target, params object[] arguments) : base(invocationId) + { + if (string.IsNullOrEmpty(invocationId)) + { + throw new ArgumentNullException(nameof(invocationId)); + } + + if (string.IsNullOrEmpty(target)) + { + throw new ArgumentNullException(nameof(target)); + } + + if (arguments == null) + { + throw new ArgumentNullException(nameof(arguments)); + } + + Target = target; + Arguments = arguments; + NonBlocking = nonBlocking; + } + + public override string ToString() + { + return $"Invocation {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(NonBlocking)}: {NonBlocking}, {nameof(Target)}: \"{Target}\", {nameof(Arguments)}: [ {string.Join(", ", Arguments.Select(a => a?.ToString()))} ] }}"; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs new file mode 100644 index 0000000000..41e4db3d7a --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs @@ -0,0 +1,281 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO; +using Microsoft.AspNetCore.Sockets; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public class JsonHubProtocol : IHubProtocol + { + private const string ResultPropertyName = "result"; + private const string InvocationIdPropertyName = "invocationId"; + private const string TypePropertyName = "type"; + private const string ErrorPropertyName = "error"; + private const string TargetPropertyName = "target"; + private const string NonBlockingPropertyName = "nonBlocking"; + private const string ArgumentsPropertyName = "arguments"; + + private const int InvocationMessageType = 1; + private const int ResultMessageType = 2; + private const int CompletionMessageType = 3; + + // ONLY to be used for application payloads (args, return values, etc.) + private JsonSerializer _payloadSerializer; + + public MessageType MessageType => MessageType.Text; + + /// + /// Creates an instance of the using the specified + /// to serialize application payloads (arguments, results, etc.). The serialization of the outer protocol can + /// NOT be changed using this serializer. + /// + /// The to use to serialize application payloads (arguments, results, etc.). + public JsonHubProtocol(JsonSerializer payloadSerializer) + { + if (payloadSerializer == null) + { + throw new ArgumentNullException(nameof(payloadSerializer)); + } + + _payloadSerializer = payloadSerializer; + } + + public HubMessage ParseMessage(ReadOnlySpan input, IInvocationBinder binder) + { + // TODO: Need a span-native JSON parser! + using (var memoryStream = new MemoryStream(input.ToArray())) + { + return ParseMessage(memoryStream, binder); + } + } + + public bool TryWriteMessage(HubMessage message, IOutput output) + { + // TODO: Need IOutput-compatible JSON serializer! + using (var memoryStream = new MemoryStream()) + { + WriteMessage(message, memoryStream); + memoryStream.Flush(); + + return output.TryWrite(memoryStream.ToArray()); + } + } + + private HubMessage ParseMessage(Stream input, IInvocationBinder binder) + { + using (var reader = new JsonTextReader(new StreamReader(input))) + { + try + { + // PERF: Could probably use the JsonTextReader directly for better perf and fewer allocations + var token = JToken.ReadFrom(reader); + if (token == null) + { + return null; + } + + if (token.Type != JTokenType.Object) + { + throw new FormatException($"Unexpected JSON Token Type '{token.Type}'. Expected a JSON Object."); + } + + var json = (JObject)token; + + // Determine the type of the message + var type = GetRequiredProperty(json, TypePropertyName, JTokenType.Integer); + switch (type) + { + case InvocationMessageType: + return BindInvocationMessage(json, binder); + case ResultMessageType: + return BindResultMessage(json, binder); + case CompletionMessageType: + return BindCompletionMessage(json, binder); + default: + throw new FormatException($"Unknown message type: {type}"); + } + } + catch (JsonReaderException jrex) + { + throw new FormatException("Error reading JSON.", jrex); + } + } + } + + private void WriteMessage(HubMessage message, Stream stream) + { + using (var writer = new JsonTextWriter(new StreamWriter(stream))) + { + switch (message) + { + case InvocationMessage m: + WriteInvocationMessage(m, writer); + break; + case StreamItemMessage m: + WriteResultMessage(m, writer); + break; + case CompletionMessage m: + WriteCompletionMessage(m, writer); + break; + default: + throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}"); + } + } + } + + private void WriteCompletionMessage(CompletionMessage message, JsonTextWriter writer) + { + writer.WriteStartObject(); + WriteHubMessageCommon(message, writer, CompletionMessageType); + if (!string.IsNullOrEmpty(message.Error)) + { + writer.WritePropertyName(ErrorPropertyName); + writer.WriteValue(message.Error); + } + else if (message.HasResult) + { + writer.WritePropertyName(ResultPropertyName); + _payloadSerializer.Serialize(writer, message.Result); + } + writer.WriteEndObject(); + } + + private void WriteResultMessage(StreamItemMessage message, JsonTextWriter writer) + { + writer.WriteStartObject(); + WriteHubMessageCommon(message, writer, ResultMessageType); + writer.WritePropertyName(ResultPropertyName); + _payloadSerializer.Serialize(writer, message.Item); + writer.WriteEndObject(); + } + + private void WriteInvocationMessage(InvocationMessage message, JsonTextWriter writer) + { + writer.WriteStartObject(); + WriteHubMessageCommon(message, writer, InvocationMessageType); + writer.WritePropertyName(TargetPropertyName); + writer.WriteValue(message.Target); + + if (message.NonBlocking) + { + writer.WritePropertyName(NonBlockingPropertyName); + writer.WriteValue(message.NonBlocking); + } + + writer.WritePropertyName(ArgumentsPropertyName); + writer.WriteStartArray(); + foreach (var argument in message.Arguments) + { + _payloadSerializer.Serialize(writer, argument); + } + writer.WriteEndArray(); + + writer.WriteEndObject(); + } + + private static void WriteHubMessageCommon(HubMessage message, JsonTextWriter writer, int type) + { + writer.WritePropertyName(InvocationIdPropertyName); + writer.WriteValue(message.InvocationId); + writer.WritePropertyName(TypePropertyName); + writer.WriteValue(type); + } + + private InvocationMessage BindInvocationMessage(JObject json, IInvocationBinder binder) + { + var invocationId = GetRequiredProperty(json, InvocationIdPropertyName, JTokenType.String); + var target = GetRequiredProperty(json, TargetPropertyName, JTokenType.String); + var nonBlocking = GetOptionalProperty(json, NonBlockingPropertyName, JTokenType.Boolean); + + var args = GetRequiredProperty(json, ArgumentsPropertyName, JTokenType.Array); + + var paramTypes = binder.GetParameterTypes(target); + var arguments = new object[args.Count]; + if (paramTypes.Length != arguments.Length) + { + throw new FormatException($"Invocation provides {arguments.Length} argument(s) but target expects {paramTypes.Length}."); + } + + for (var i = 0; i < paramTypes.Length; i++) + { + var paramType = paramTypes[i]; + + // TODO(anurse): We can add some DI magic here to allow users to provide their own serialization + // Related Bug: https://github.com/aspnet/SignalR/issues/261 + arguments[i] = args[i].ToObject(paramType, _payloadSerializer); + } + + return new InvocationMessage(invocationId, nonBlocking, target, arguments); + } + + private StreamItemMessage BindResultMessage(JObject json, IInvocationBinder binder) + { + var invocationId = GetRequiredProperty(json, InvocationIdPropertyName, JTokenType.String); + var result = GetRequiredProperty(json, ResultPropertyName); + + var returnType = binder.GetReturnType(invocationId); + return new StreamItemMessage(invocationId, result?.ToObject(returnType, _payloadSerializer)); + } + + private CompletionMessage BindCompletionMessage(JObject json, IInvocationBinder binder) + { + var invocationId = GetRequiredProperty(json, InvocationIdPropertyName, JTokenType.String); + var error = GetOptionalProperty(json, ErrorPropertyName, JTokenType.String); + var resultProp = json.Property(ResultPropertyName); + + if (error != null && resultProp != null) + { + throw new FormatException("The 'error' and 'result' properties are mutually exclusive."); + } + + if (resultProp == null) + { + return new CompletionMessage(invocationId, error, result: null, hasResult: false); + } + else + { + var returnType = binder.GetReturnType(invocationId); + var payload = resultProp.Value?.ToObject(returnType, _payloadSerializer); + return new CompletionMessage(invocationId, error, result: payload, hasResult: true); + } + } + + private T GetOptionalProperty(JObject json, string property, JTokenType expectedType = JTokenType.None, T defaultValue = default(T)) + { + var prop = json[property]; + + if (prop == null) + { + return defaultValue; + } + + return GetValue(property, expectedType, prop); + } + + private T GetRequiredProperty(JObject json, string property, JTokenType expectedType = JTokenType.None) + { + var prop = json[property]; + + if (prop == null) + { + throw new FormatException($"Missing required property '{property}'."); + } + + return GetValue(property, expectedType, prop); + } + + private static T GetValue(string property, JTokenType expectedType, JToken prop) + { + if (expectedType != JTokenType.None && prop.Type != expectedType) + { + throw new FormatException($"Expected '{property}' to be of type {expectedType}."); + } + return prop.Value(); + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/StreamItemMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/StreamItemMessage.cs new file mode 100644 index 0000000000..dec753f948 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/StreamItemMessage.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public class StreamItemMessage : HubMessage + { + public object Item { get; } + + public StreamItemMessage(string invocationId, object item) : base(invocationId) + { + Item = item; + } + + public override string ToString() + { + return $"StreamItem {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Item)}: {Item ?? "<>"} }}"; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/InvocationAdapterExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Common/InvocationAdapterExtensions.cs deleted file mode 100644 index 8b33886b1b..0000000000 --- a/src/Microsoft.AspNetCore.SignalR.Common/InvocationAdapterExtensions.cs +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.SignalR -{ - public static class InvocationAdapterExtensions - { - public static Task ReadMessageAsync(this IInvocationAdapter self, Stream stream, IInvocationBinder binder) => self.ReadMessageAsync(stream, binder, CancellationToken.None); - - public static Task WriteMessageAsync(this IInvocationAdapter self, InvocationMessage message, Stream stream) => self.WriteMessageAsync(message, stream, CancellationToken.None); - } -} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/InvocationDescriptor.cs b/src/Microsoft.AspNetCore.SignalR.Common/InvocationDescriptor.cs deleted file mode 100644 index a15205192d..0000000000 --- a/src/Microsoft.AspNetCore.SignalR.Common/InvocationDescriptor.cs +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; - -namespace Microsoft.AspNetCore.SignalR -{ - public class InvocationDescriptor : InvocationMessage - { - public string Method { get; set; } - - public object[] Arguments { get; set; } - - public override string ToString() - { - return $"{Id}: {Method}({(Arguments ?? new object[0]).Length})"; - } - } -} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/InvocationResultDescriptor.cs b/src/Microsoft.AspNetCore.SignalR.Common/InvocationResultDescriptor.cs deleted file mode 100644 index 32428a32c7..0000000000 --- a/src/Microsoft.AspNetCore.SignalR.Common/InvocationResultDescriptor.cs +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.SignalR -{ - public class InvocationResultDescriptor : InvocationMessage - { - public object Result { get; set; } - - public string Error { get; set; } - } -} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/JsonNetInvocationAdapter.cs b/src/Microsoft.AspNetCore.SignalR.Common/JsonNetInvocationAdapter.cs deleted file mode 100644 index 8c7a1a052d..0000000000 --- a/src/Microsoft.AspNetCore.SignalR.Common/JsonNetInvocationAdapter.cs +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.Internal; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; - -namespace Microsoft.AspNetCore.SignalR -{ - public class JsonNetInvocationAdapter : IInvocationAdapter - { - private JsonSerializer _serializer = new JsonSerializer(); - - public JsonNetInvocationAdapter() - { - } - - public Task ReadMessageAsync(Stream stream, IInvocationBinder binder, CancellationToken cancellationToken) - { - var reader = new JsonTextReader(new StreamReader(stream)); - // REVIEW: Task.Run() - return Task.Run(() => - { - cancellationToken.ThrowIfCancellationRequested(); - var json = _serializer.Deserialize(reader); - if (json == null) - { - return null; - } - - // Determine the type of the message - if (json["Result"] != null) - { - // It's a result - return BindInvocationResultDescriptor(json, binder, cancellationToken); - } - else - { - return BindInvocationDescriptor(json, binder, cancellationToken); - } - }, cancellationToken); - } - - public Task WriteMessageAsync(InvocationMessage message, Stream stream, CancellationToken cancellationToken) - { - var writer = new JsonTextWriter(new StreamWriter(stream)); - _serializer.Serialize(writer, message); - writer.Flush(); - return TaskCache.CompletedTask; - } - - private InvocationDescriptor BindInvocationDescriptor(JObject json, IInvocationBinder binder, CancellationToken cancellationToken) - { - var invocation = new InvocationDescriptor - { - Id = json.Value("Id"), - Method = json.Value("Method"), - }; - - var paramTypes = binder.GetParameterTypes(invocation.Method); - invocation.Arguments = new object[paramTypes.Length]; - - var args = json.Value("Arguments"); - for (var i = 0; i < paramTypes.Length; i++) - { - var paramType = paramTypes[i]; - invocation.Arguments[i] = args[i].ToObject(paramType, _serializer); - } - - return invocation; - } - - private InvocationResultDescriptor BindInvocationResultDescriptor(JObject json, IInvocationBinder binder, CancellationToken cancellationToken) - { - var id = json.Value("Id"); - var returnType = binder.GetReturnType(id); - var result = new InvocationResultDescriptor() - { - Id = id, - Result = returnType == null ? null : json["Result"].ToObject(returnType, _serializer), - Error = json.Value("Error") - }; - return result; - } - } -} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj b/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj index 9ac853b1a3..138643f8d1 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj +++ b/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj @@ -9,11 +9,20 @@ true aspnetcore;signalr false + Microsoft.AspNetCore.SignalR - + + + + + + + + + diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj b/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj index e60402fb46..551cab5dd5 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj @@ -13,7 +13,6 @@ - diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index d744de4130..cb847095a6 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -5,53 +5,80 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; -using System.IO.Pipelines; using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; -using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Newtonsoft.Json; using StackExchange.Redis; namespace Microsoft.AspNetCore.SignalR.Redis { public class RedisHubLifetimeManager : HubLifetimeManager, IDisposable { + private const string RedisSubscriptionsMetadataName = "redis_subscriptions"; + private readonly ConnectionList _connections = new ConnectionList(); // TODO: Investigate "memory leak" entries never get removed private readonly ConcurrentDictionary _groups = new ConcurrentDictionary(); - private readonly InvocationAdapterRegistry _registry; private readonly ConnectionMultiplexer _redisServerConnection; private readonly ISubscriber _bus; - private readonly ILoggerFactory _loggerFactory; + private readonly ILogger _logger; private readonly RedisOptions _options; - public RedisHubLifetimeManager(InvocationAdapterRegistry registry, - ILoggerFactory loggerFactory, + // This serializer is ONLY use to transmit the data through redis, it has no connection to the serializer used on each connection. + private readonly JsonSerializer _serializer = new JsonSerializer + { + // We need to serialize objects "full-fidelity", even if it is noisy, so we preserve the original types + TypeNameAssemblyFormatHandling = TypeNameAssemblyFormatHandling.Simple, + TypeNameHandling = TypeNameHandling.All, + Formatting = Formatting.None + }; + + private long _nextInvocationId = 0; + + public RedisHubLifetimeManager(ILogger> logger, IOptions options) { - _loggerFactory = loggerFactory; - _registry = registry; + _logger = logger; _options = options.Value; - var writer = new LoggerTextWriter(loggerFactory.CreateLogger>()); + var writer = new LoggerTextWriter(logger); + _logger.LogInformation("Connecting to redis endpoints: {endpoints}", string.Join(", ", options.Value.Options.EndPoints.Select(e => EndPointCollection.ToString(e)))); _redisServerConnection = _options.Connect(writer); + if (_redisServerConnection.IsConnected) + { + _logger.LogInformation("Connected to redis"); + } + else + { + // TODO: We could support reconnecting, like old SignalR does. + throw new InvalidOperationException("Connection to redis failed."); + } _bus = _redisServerConnection.GetSubscriber(); - var previousBroadcastTask = TaskCache.CompletedTask; + var previousBroadcastTask = Task.CompletedTask; - _bus.Subscribe(typeof(THub).FullName, async (c, data) => + var channelName = typeof(THub).FullName; + _logger.LogInformation("Subscribing to channel: {channel}", channelName); + _bus.Subscribe(channelName, async (c, data) => { await previousBroadcastTask; + _logger.LogTrace("Received message from redis channel {channel}", channelName); + + var message = DeserializeMessage(data); + + // TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf var tasks = new List(_connections.Count); foreach (var connection in _connections) { - tasks.Add(WriteAsync(connection, data)); + tasks.Add(WriteAsync(connection, message)); } previousBroadcastTask = Task.WhenAll(tasks); @@ -60,80 +87,68 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override Task InvokeAllAsync(string methodName, object[] args) { - var message = new InvocationDescriptor - { - Method = methodName, - Arguments = args - }; + var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); return PublishAsync(typeof(THub).FullName, message); } public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args) { - var message = new InvocationDescriptor - { - Method = methodName, - Arguments = args - }; + var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); return PublishAsync(typeof(THub).FullName + "." + connectionId, message); } public override Task InvokeGroupAsync(string groupName, string methodName, object[] args) { - var message = new InvocationDescriptor - { - Method = methodName, - Arguments = args - }; + var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); return PublishAsync(typeof(THub).FullName + ".group." + groupName, message); } public override Task InvokeUserAsync(string userId, string methodName, object[] args) { - var message = new InvocationDescriptor - { - Method = methodName, - Arguments = args - }; + var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); return PublishAsync(typeof(THub).FullName + ".user." + userId, message); } - private async Task PublishAsync(string channel, InvocationDescriptor message) + private async Task PublishAsync(string channel, HubMessage hubMessage) { - // TODO: What format?? - var invocationAdapter = _registry.GetInvocationAdapter("json"); - - // BAD - using (var ms = new MemoryStream()) + byte[] payload; + using (var stream = new MemoryStream()) + using (var writer = new JsonTextWriter(new StreamWriter(stream))) { - await invocationAdapter.WriteMessageAsync(message, ms); - - await _bus.PublishAsync(channel, ms.ToArray()); + _serializer.Serialize(writer, hubMessage); + await writer.FlushAsync(); + payload = stream.ToArray(); } + + _logger.LogTrace("Publishing message to redis channel {channel}", channel); + await _bus.PublishAsync(channel, payload); } public override Task OnConnectedAsync(Connection connection) { - var redisSubscriptions = connection.Metadata.GetOrAdd("redis_subscriptions", _ => new HashSet()); - var connectionTask = TaskCache.CompletedTask; - var userTask = TaskCache.CompletedTask; + var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet()); + var connectionTask = Task.CompletedTask; + var userTask = Task.CompletedTask; _connections.Add(connection); var connectionChannel = typeof(THub).FullName + "." + connection.ConnectionId; redisSubscriptions.Add(connectionChannel); - var previousConnectionTask = TaskCache.CompletedTask; + var previousConnectionTask = Task.CompletedTask; + _logger.LogInformation("Subscribing to connection channel: {channel}", connectionChannel); connectionTask = _bus.SubscribeAsync(connectionChannel, async (c, data) => { await previousConnectionTask; - previousConnectionTask = WriteAsync(connection, data); + var message = DeserializeMessage(data); + + previousConnectionTask = WriteAsync(connection, message); }); @@ -142,14 +157,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis var userChannel = typeof(THub).FullName + ".user." + connection.User.Identity.Name; redisSubscriptions.Add(userChannel); - var previousUserTask = TaskCache.CompletedTask; + var previousUserTask = Task.CompletedTask; // TODO: Look at optimizing (looping over connections checking for Name) userTask = _bus.SubscribeAsync(userChannel, async (c, data) => { await previousUserTask; - previousUserTask = WriteAsync(connection, data); + var message = DeserializeMessage(data); + + previousUserTask = WriteAsync(connection, message); }); } @@ -162,16 +179,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis var tasks = new List(); - var redisSubscriptions = connection.Metadata.Get>("redis_subscriptions"); + var redisSubscriptions = connection.Metadata.Get>(RedisSubscriptionsMetadataName); if (redisSubscriptions != null) { foreach (var subscription in redisSubscriptions) { + _logger.LogInformation("Unsubscribing from channel: {channel}", subscription); tasks.Add(_bus.UnsubscribeAsync(subscription)); } } - var groupNames = connection.Metadata.Get>("group"); + var groupNames = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); if (groupNames != null) { @@ -190,7 +208,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { var groupChannel = typeof(THub).FullName + ".group." + groupName; - var groupNames = connection.Metadata.GetOrAdd("group", _ => new HashSet()); + var groupNames = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet()); lock (groupNames) { @@ -210,8 +228,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis return; } - var previousTask = TaskCache.CompletedTask; + var previousTask = Task.CompletedTask; + _logger.LogInformation("Subscribing to group channel: {channel}", groupChannel); await _bus.SubscribeAsync(groupChannel, async (c, data) => { // Since this callback is async, we await the previous task then @@ -219,10 +238,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis // want to do concurrent writes to the outgoing connections await previousTask; + var message = DeserializeMessage(data); + var tasks = new List(group.Connections.Count); foreach (var groupConnection in group.Connections) { - tasks.Add(WriteAsync(groupConnection, data)); + tasks.Add(WriteAsync(groupConnection, message)); } previousTask = Task.WhenAll(tasks); @@ -244,7 +265,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis return; } - var groupNames = connection.Metadata.Get>("group"); + var groupNames = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); if (groupNames != null) { lock (groupNames) @@ -260,6 +281,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis if (group.Connections.Count == 0) { + _logger.LogInformation("Unsubscribing from group channel: {channel}", groupChannel); await _bus.UnsubscribeAsync(groupChannel); } } @@ -275,9 +297,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis _redisServerConnection.Dispose(); } - private async Task WriteAsync(Connection connection, byte[] data) + private async Task WriteAsync(Connection connection, HubMessage hubMessage) { - var message = new Message(data, MessageType.Text, endOfMessage: true); + var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol); + var data = await protocol.WriteToArrayAsync(hubMessage); + var message = new Message(data, protocol.MessageType, endOfMessage: true); while (await connection.Transport.Output.WaitToWriteAsync()) { @@ -288,6 +312,23 @@ namespace Microsoft.AspNetCore.SignalR.Redis } } + private string GetInvocationId() + { + var invocationId = Interlocked.Increment(ref _nextInvocationId); + return invocationId.ToString(); + } + + private HubMessage DeserializeMessage(RedisValue data) + { + HubMessage message; + using (var reader = new JsonTextReader(new StreamReader(new MemoryStream((byte[])data)))) + { + message = (HubMessage)_serializer.Deserialize(reader); + } + + return message; + } + private class LoggerTextWriter : TextWriter { private readonly ILogger _logger; diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs index a2842441f0..c8b83fe4a2 100644 --- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs @@ -3,43 +3,37 @@ using System; using System.Collections.Generic; -using System.IO; -using System.IO.Pipelines; +using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; -using Microsoft.Extensions.Internal; namespace Microsoft.AspNetCore.SignalR { public class DefaultHubLifetimeManager : HubLifetimeManager { + private long _nextInvocationId = 0; private readonly ConnectionList _connections = new ConnectionList(); - private readonly InvocationAdapterRegistry _registry; - - public DefaultHubLifetimeManager(InvocationAdapterRegistry registry) - { - _registry = registry; - } public override Task AddGroupAsync(Connection connection, string groupName) { - var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet()); + var groups = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet()); lock (groups) { groups.Add(groupName); } - return TaskCache.CompletedTask; + return Task.CompletedTask; } public override Task RemoveGroupAsync(Connection connection, string groupName) { - var groups = connection.Metadata.Get>("groups"); + var groups = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); if (groups == null) { - return TaskCache.CompletedTask; + return Task.CompletedTask; } lock (groups) @@ -47,7 +41,7 @@ namespace Microsoft.AspNetCore.SignalR groups.Remove(groupName); } - return TaskCache.CompletedTask; + return Task.CompletedTask; } public override Task InvokeAllAsync(string methodName, object[] args) @@ -58,11 +52,7 @@ namespace Microsoft.AspNetCore.SignalR private Task InvokeAllWhere(string methodName, object[] args, Func include) { var tasks = new List(_connections.Count); - var message = new InvocationDescriptor - { - Method = methodName, - Arguments = args - }; + var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); // TODO: serialize once per format by providing a different stream? foreach (var connection in _connections) @@ -72,9 +62,7 @@ namespace Microsoft.AspNetCore.SignalR continue; } - var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get("formatType")); - - tasks.Add(WriteAsync(connection, invocationAdapter, message)); + tasks.Add(WriteAsync(connection, message)); } return Task.WhenAll(tasks); @@ -84,22 +72,16 @@ namespace Microsoft.AspNetCore.SignalR { var connection = _connections[connectionId]; - var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get("formatType")); + var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); - var message = new InvocationDescriptor - { - Method = methodName, - Arguments = args - }; - - return WriteAsync(connection, invocationAdapter, message); + return WriteAsync(connection, message); } public override Task InvokeGroupAsync(string groupName, string methodName, object[] args) { return InvokeAllWhere(methodName, args, connection => { - var groups = connection.Metadata.Get>("groups"); + var groups = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); return groups?.Contains(groupName) == true; }); } @@ -115,21 +97,20 @@ namespace Microsoft.AspNetCore.SignalR public override Task OnConnectedAsync(Connection connection) { _connections.Add(connection); - return TaskCache.CompletedTask; + return Task.CompletedTask; } public override Task OnDisconnectedAsync(Connection connection) { _connections.Remove(connection); - return TaskCache.CompletedTask; + return Task.CompletedTask; } - private static async Task WriteAsync(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor invocation) + private async Task WriteAsync(Connection connection, HubMessage hubMessage) { - var stream = new MemoryStream(); - await invocationAdapter.WriteMessageAsync(invocation, stream); - - var message = new Message(stream.ToArray(), MessageType.Text, endOfMessage: true); + var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol); + var payload = await protocol.WriteToArrayAsync(hubMessage); + var message = new Message(payload, protocol.MessageType, endOfMessage: true); while (await connection.Transport.Output.WaitToWriteAsync()) { @@ -139,5 +120,11 @@ namespace Microsoft.AspNetCore.SignalR } } } + + private string GetInvocationId() + { + var invocationId = Interlocked.Increment(ref _nextInvocationId); + return invocationId.ToString(); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR/Hub.cs b/src/Microsoft.AspNetCore.SignalR/Hub.cs index 5ff206d408..3ecc69bf51 100644 --- a/src/Microsoft.AspNetCore.SignalR/Hub.cs +++ b/src/Microsoft.AspNetCore.SignalR/Hub.cs @@ -3,7 +3,6 @@ using System; using System.Threading.Tasks; -using Microsoft.Extensions.Internal; namespace Microsoft.AspNetCore.SignalR { @@ -62,12 +61,12 @@ namespace Microsoft.AspNetCore.SignalR public virtual Task OnConnectedAsync() { - return TaskCache.CompletedTask; + return Task.CompletedTask; } public virtual Task OnDisconnectedAsync(Exception exception) { - return TaskCache.CompletedTask; + return Task.CompletedTask; } protected virtual void Dispose(bool disposing) diff --git a/src/Microsoft.AspNetCore.SignalR.Common/InvocationMessage.cs b/src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs similarity index 54% rename from src/Microsoft.AspNetCore.SignalR.Common/InvocationMessage.cs rename to src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs index fc0c1bb5ca..702a762c01 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/InvocationMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs @@ -1,15 +1,11 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; - namespace Microsoft.AspNetCore.SignalR { - public abstract class InvocationMessage + public static class HubConnectionMetadataNames { - public string Id { get; set; } + public static readonly string HubProtocol = nameof(HubProtocol); + public static readonly string Groups = nameof(Groups); } } diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index bbe5f6717c..82af05346f 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -1,13 +1,14 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Collections.Generic; -using System.IO; using System.Linq; using System.Reflection; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Internal; @@ -19,12 +20,12 @@ namespace Microsoft.AspNetCore.SignalR public class HubEndPoint : HubEndPoint where THub : Hub { public HubEndPoint(HubLifetimeManager lifetimeManager, + IHubProtocolResolver protocolResolver, IHubContext hubContext, - InvocationAdapterRegistry registry, IOptions>> endPointOptions, ILogger> logger, IServiceScopeFactory serviceScopeFactory) - : base(lifetimeManager, hubContext, registry, endPointOptions, logger, serviceScopeFactory) + : base(lifetimeManager, protocolResolver, hubContext, endPointOptions, logger, serviceScopeFactory) { } } @@ -36,19 +37,19 @@ namespace Microsoft.AspNetCore.SignalR private readonly HubLifetimeManager _lifetimeManager; private readonly IHubContext _hubContext; private readonly ILogger> _logger; - private readonly InvocationAdapterRegistry _registry; private readonly IServiceScopeFactory _serviceScopeFactory; + private readonly IHubProtocolResolver _protocolResolver; public HubEndPoint(HubLifetimeManager lifetimeManager, + IHubProtocolResolver protocolResolver, IHubContext hubContext, - InvocationAdapterRegistry registry, IOptions>> endPointOptions, ILogger> logger, IServiceScopeFactory serviceScopeFactory) { + _protocolResolver = protocolResolver; _lifetimeManager = lifetimeManager; _hubContext = hubContext; - _registry = registry; _logger = logger; _serviceScopeFactory = serviceScopeFactory; @@ -59,6 +60,11 @@ namespace Microsoft.AspNetCore.SignalR { try { + // Resolve the Hub Protocol for the connection and store it in metadata + // Other components, outside the Hub, may need to know what protocol is in use + // for a particular connection, so we store it here. + connection.Metadata[HubConnectionMetadataNames.HubProtocol] = _protocolResolver.GetProtocol(connection); + await _lifetimeManager.OnConnectedAsync(connection); await RunHubAsync(connection); } @@ -140,14 +146,13 @@ namespace Microsoft.AspNetCore.SignalR private async Task DispatchMessagesAsync(Connection connection) { - var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get("formatType")); - // We use these for error handling. Since we dispatch multiple hub invocations // in parallel, we need a way to communicate failure back to the main processing loop. The // cancellation token is used to stop reading from the channel, the tcs // is used to get the exception so we can bubble it up the stack var cts = new CancellationTokenSource(); - var tcs = new TaskCompletionSource(); + var completion = new TaskCompletionSource(); + var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol); try { @@ -155,100 +160,94 @@ namespace Microsoft.AspNetCore.SignalR { while (connection.Transport.Input.TryRead(out var incomingMessage)) { - InvocationDescriptor invocationDescriptor; - var inputStream = new MemoryStream(incomingMessage.Payload); + var hubMessage = protocol.ParseMessage(incomingMessage.Payload, this); - // TODO: Handle receiving InvocationResultDescriptor - invocationDescriptor = await invocationAdapter.ReadMessageAsync(inputStream, this) as InvocationDescriptor; - - // Is there a better way of detecting that a connection was closed? - if (invocationDescriptor == null) + switch (hubMessage) { - break; - } + case InvocationMessage invocationMessage: + if (_logger.IsEnabled(LogLevel.Debug)) + { + _logger.LogDebug("Received hub invocation: {invocation}", invocationMessage); + } - if (_logger.IsEnabled(LogLevel.Debug)) - { - _logger.LogDebug("Received hub invocation: {invocation}", invocationDescriptor); - } + // Don't wait on the result of execution, continue processing other + // incoming messages on this connection. + var ignore = ProcessInvocation(connection, protocol, invocationMessage, cts, completion); + break; - // Don't wait on the result of execution, continue processing other - // incoming messages on this connection. - var ignore = ProcessInvocation(connection, invocationAdapter, invocationDescriptor, cts, tcs); + // Other kind of message we weren't expecting + default: + _logger.LogError("Received unsupported message of type '{messageType}'", hubMessage.GetType().FullName); + throw new NotSupportedException($"Received unsupported message: {hubMessage}"); + } } } } catch (OperationCanceledException) { // Await the task so the exception bubbles up to the caller - await tcs.Task; + await completion.Task; } } private async Task ProcessInvocation(Connection connection, - IInvocationAdapter invocationAdapter, - InvocationDescriptor invocationDescriptor, - CancellationTokenSource cts, - TaskCompletionSource tcs) + IHubProtocol protocol, + InvocationMessage invocationMessage, + CancellationTokenSource dispatcherCancellation, + TaskCompletionSource dispatcherCompletion) { try { // If an unexpected exception occurs then we want to kill the entire connection // by ending the processing loop - await Execute(connection, invocationAdapter, invocationDescriptor); + await Execute(connection, protocol, invocationMessage); } catch (Exception ex) { // Set the exception on the task completion source - tcs.TrySetException(ex); + dispatcherCompletion.TrySetException(ex); // Cancel reading operation - cts.Cancel(); + dispatcherCancellation.Cancel(); } } - private async Task Execute(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor invocationDescriptor) + private async Task Execute(Connection connection, IHubProtocol protocol, InvocationMessage invocationMessage) { - InvocationResultDescriptor result; HubMethodDescriptor descriptor; - if (_methods.TryGetValue(invocationDescriptor.Method, out descriptor)) + if (!_methods.TryGetValue(invocationMessage.Target, out descriptor)) { - result = await Invoke(descriptor, connection, invocationDescriptor); + // Send an error to the client. Then let the normal completion process occur + _logger.LogError("Unknown hub method '{method}'", invocationMessage.Target); + await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, $"Unknown hub method '{invocationMessage.Target}'")); } else { - // If there's no method then return a failed response for this request - result = new InvocationResultDescriptor - { - Id = invocationDescriptor.Id, - Error = $"Unknown hub method '{invocationDescriptor.Method}'" - }; - - _logger.LogError("Unknown hub method '{method}'", invocationDescriptor.Method); - } - - // TODO: Pool memory - var outStream = new MemoryStream(); - await invocationAdapter.WriteMessageAsync(result, outStream); - - var outMessage = new Message(outStream.ToArray(), MessageType.Text, endOfMessage: true); - - while (await connection.Transport.Output.WaitToWriteAsync()) - { - if (connection.Transport.Output.TryWrite(outMessage)) - { - break; - } + var result = await Invoke(descriptor, connection, invocationMessage); + await SendMessageAsync(connection, protocol, result); } } - private async Task Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationDescriptor invocationDescriptor) + private async Task SendMessageAsync(Connection connection, IHubProtocol protocol, HubMessage hubMessage) { - var invocationResult = new InvocationResultDescriptor - { - Id = invocationDescriptor.Id - }; + var payload = await protocol.WriteToArrayAsync(hubMessage); + var message = new Message(payload, protocol.MessageType, endOfMessage: true); + while (await connection.Transport.Output.WaitToWriteAsync()) + { + if (connection.Transport.Output.TryWrite(message)) + { + return; + } + } + + // Output is closed. Cancel this invocation completely + _logger.LogWarning("Outbound channel was closed while trying to write hub message"); + throw new OperationCanceledException("Outbound channel was closed while trying to write hub message"); + } + + private async Task Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationMessage invocationMessage) + { var methodExecutor = descriptor.MethodExecutor; using (var scope = _serviceScopeFactory.CreateScope()) @@ -265,37 +264,35 @@ namespace Microsoft.AspNetCore.SignalR { if (methodExecutor.MethodReturnType == typeof(Task)) { - await (Task)methodExecutor.Execute(hub, invocationDescriptor.Arguments); + await (Task)methodExecutor.Execute(hub, invocationMessage.Arguments); } else { - result = await methodExecutor.ExecuteAsync(hub, invocationDescriptor.Arguments); + result = await methodExecutor.ExecuteAsync(hub, invocationMessage.Arguments); } } else { - result = methodExecutor.Execute(hub, invocationDescriptor.Arguments); + result = methodExecutor.Execute(hub, invocationMessage.Arguments); } - invocationResult.Result = result; + return CompletionMessage.WithResult(invocationMessage.InvocationId, result); } catch (TargetInvocationException ex) { _logger.LogError(0, ex, "Failed to invoke hub method"); - invocationResult.Error = ex.InnerException.Message; + return CompletionMessage.WithError(invocationMessage.InvocationId, ex.InnerException.Message); } catch (Exception ex) { _logger.LogError(0, ex, "Failed to invoke hub method"); - invocationResult.Error = ex.Message; + return CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message); } finally { hubActivator.Release(hub); } } - - return invocationResult; } private void InitializeHub(THub hub, Connection connection) diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs new file mode 100644 index 0000000000..2d06b8bcee --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.Sockets; +using Newtonsoft.Json; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + public class DefaultHubProtocolResolver : IHubProtocolResolver + { + public IHubProtocol GetProtocol(Connection connection) + { + // TODO: Allow customization of this serializer! + return new JsonHubProtocol(new JsonSerializer()); + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs new file mode 100644 index 0000000000..c9627d0d59 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.Sockets; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + public interface IHubProtocolResolver + { + IHubProtocol GetProtocol(Connection connection); + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/InvocationAdapterRegistry.cs b/src/Microsoft.AspNetCore.SignalR/InvocationAdapterRegistry.cs deleted file mode 100644 index 112b8d9345..0000000000 --- a/src/Microsoft.AspNetCore.SignalR/InvocationAdapterRegistry.cs +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Options; - -namespace Microsoft.AspNetCore.SignalR -{ - public class InvocationAdapterRegistry - { - private readonly IServiceProvider _serviceProvider; - private readonly SignalROptions _options; - - public InvocationAdapterRegistry(IOptions options, IServiceProvider serviceProvider) - { - _options = options.Value; - _serviceProvider = serviceProvider; - } - - public IInvocationAdapter GetInvocationAdapter(string format) - { - Type type; - if (_options._invocationMappings.TryGetValue(format, out type)) - { - return _serviceProvider.GetRequiredService(type) as IInvocationAdapter; - } - - return null; - } - } -} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj b/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj index 83080985e8..aedd17a019 100644 --- a/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj +++ b/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj @@ -14,7 +14,6 @@ - diff --git a/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs index bceb79c04f..71a5f022ef 100644 --- a/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs @@ -1,9 +1,8 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; using Microsoft.AspNetCore.SignalR; -using Microsoft.Extensions.Options; +using Microsoft.AspNetCore.SignalR.Internal; namespace Microsoft.Extensions.DependencyInjection { @@ -13,26 +12,13 @@ namespace Microsoft.Extensions.DependencyInjection { services.AddSockets(); services.AddSingleton(typeof(HubLifetimeManager<>), typeof(DefaultHubLifetimeManager<>)); + services.AddSingleton(typeof(IHubProtocolResolver), typeof(DefaultHubProtocolResolver)); services.AddSingleton(typeof(IHubContext<>), typeof(HubContext<>)); services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>)); - services.AddSingleton, SignalROptionsSetup>(); - services.AddSingleton(); - services.AddSingleton(); services.AddScoped(typeof(IHubActivator<,>), typeof(DefaultHubActivator<,>)); services.AddRouting(); return new SignalRBuilder(services); } - - public static ISignalRBuilder AddSignalR(this IServiceCollection services, Action setupAction) - { - return services.AddSignalR().AddSignalROptions(setupAction); - } - - public static ISignalRBuilder AddSignalROptions(this ISignalRBuilder builder, Action setupAction) - { - builder.Services.Configure(setupAction); - return builder; - } } } diff --git a/src/Microsoft.AspNetCore.SignalR/SignalROptions.cs b/src/Microsoft.AspNetCore.SignalR/SignalROptions.cs deleted file mode 100644 index 65a00b85ca..0000000000 --- a/src/Microsoft.AspNetCore.SignalR/SignalROptions.cs +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.SignalR -{ - public class SignalROptions - { - internal readonly Dictionary _invocationMappings = new Dictionary(); - - public void RegisterInvocationAdapter(string format) where TInvocationAdapter : IInvocationAdapter - { - _invocationMappings[format] = typeof(TInvocationAdapter); - } - } -} diff --git a/src/Microsoft.AspNetCore.SignalR/SignalROptionsSetup.cs b/src/Microsoft.AspNetCore.SignalR/SignalROptionsSetup.cs deleted file mode 100644 index 546d621437..0000000000 --- a/src/Microsoft.AspNetCore.SignalR/SignalROptionsSetup.cs +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using Microsoft.AspNetCore.SignalR; -using Microsoft.Extensions.Options; - -namespace Microsoft.Extensions.DependencyInjection -{ - public class SignalROptionsSetup : IConfigureOptions - { - public void Configure(SignalROptions options) - { - options.RegisterInvocationAdapter("json"); - } - } -} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs index 6d69d094ed..f3c3576fec 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs @@ -10,7 +10,6 @@ using System.Net.Http.Headers; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets.Internal.Formatters; -using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -58,7 +57,7 @@ namespace Microsoft.AspNetCore.Sockets.Client return t; }).Unwrap(); - return TaskCache.CompletedTask; + return Task.CompletedTask; } public async Task StopAsync() diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj b/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj index 66031fa62c..3e236ac65f 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj @@ -17,7 +17,6 @@ - diff --git a/src/Microsoft.AspNetCore.Sockets.Client/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/ServerSentEventsTransport.cs index 981247100b..14a0d9792c 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/ServerSentEventsTransport.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -8,7 +8,6 @@ using System.Net.Http.Headers; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets.Internal.Formatters; -using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -61,7 +60,7 @@ namespace Microsoft.AspNetCore.Sockets.Client return t; }).Unwrap(); - return TaskCache.CompletedTask; + return Task.CompletedTask; } private async Task OpenConnection(IChannelConnection application, Uri url, CancellationToken cancellationToken) diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj b/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj index 0e12b5c4cc..ebb3ced79e 100644 --- a/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj @@ -12,6 +12,10 @@ false + + + + diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionMetadataNames.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionMetadataNames.cs new file mode 100644 index 0000000000..646673782a --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionMetadataNames.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Sockets +{ + public static class ConnectionMetadataNames + { + public static readonly string Format = nameof(Format); + public static readonly string Transport = nameof(Transport); + public static readonly string HttpContext = nameof(HttpContext); + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/EndpointDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs similarity index 93% rename from src/Microsoft.AspNetCore.Sockets/EndpointDependencyInjectionExtensions.cs rename to src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs index a83814c326..c09c6118fd 100644 --- a/src/Microsoft.AspNetCore.Sockets/EndpointDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs @@ -6,7 +6,7 @@ using Microsoft.AspNetCore.Sockets; namespace Microsoft.Extensions.DependencyInjection { - public static class EndpointDependencyInjectionExtensions + public static class EndPointDependencyInjectionExtensions { public static IServiceCollection AddEndPoint(this IServiceCollection services) where TEndPoint : EndPoint { diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 0be2cf9ad5..1edd8b71be 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -167,7 +167,7 @@ namespace Microsoft.AspNetCore.Sockets { _logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId); - state.Connection.Metadata["transport"] = TransportType.LongPolling; + state.Connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling; state.ApplicationTask = ExecuteApplication(endpoint, state.Connection); } @@ -232,13 +232,10 @@ namespace Microsoft.AspNetCore.Sockets private ConnectionState CreateConnection(HttpContext context) { var state = _manager.CreateConnection(); + var format = (string)context.Request.Query[ConnectionMetadataNames.Format]; state.Connection.User = context.User; - - // TODO: this is wrong. + how does the user add their own metadata based on HttpContext - var formatType = (string)context.Request.Query["formatType"]; - state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType; - state.Connection.Metadata[typeof(HttpContext)] = context; - + state.Connection.Metadata[ConnectionMetadataNames.HttpContext] = context; + state.Connection.Metadata[ConnectionMetadataNames.Format] = string.IsNullOrEmpty(format) ? "json" : format; return state; } @@ -355,7 +352,7 @@ namespace Microsoft.AspNetCore.Sockets var messages = ParseSendBatch(ref reader, messageFormat); // REVIEW: Do we want to return a specific status code here if the connection has ended? - _logger.LogDebug("Received batch of {0} message(s) in '/send'", messages.Count); + _logger.LogDebug("Received batch of {count} message(s) in '/send'", messages.Count); foreach (var message in messages) { while (!state.Application.Output.TryWrite(message)) @@ -379,11 +376,11 @@ namespace Microsoft.AspNetCore.Sockets connectionState.Connection.User = context.User; - var transport = connectionState.Connection.Metadata.Get("transport"); + var transport = connectionState.Connection.Metadata.Get(ConnectionMetadataNames.Transport); if (transport == null) { - connectionState.Connection.Metadata["transport"] = transportType; + connectionState.Connection.Metadata[ConnectionMetadataNames.Transport] = transportType; } else if (transport != transportType) { diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs index 6e3bd605bb..c73423b87f 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -38,7 +38,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal public async Task DisposeAsync() { - Task disposeTask = TaskCache.CompletedTask; + Task disposeTask = Task.CompletedTask; try { @@ -69,8 +69,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal Connection.Dispose(); Application.Dispose(); - var applicationTask = ApplicationTask ?? TaskCache.CompletedTask; - var transportTask = TransportTask ?? TaskCache.CompletedTask; + var applicationTask = ApplicationTask ?? Task.CompletedTask; + var transportTask = TransportTask ?? Task.CompletedTask; disposeTask = WaitOnTasks(applicationTask, transportTask); } diff --git a/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj b/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj index 0247ea2f83..fc9ba09d5f 100644 --- a/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj +++ b/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj @@ -18,7 +18,6 @@ - diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs index 1a3fe2519b..bd2c8347bf 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs @@ -134,7 +134,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports // Is this a frame we care about? if (!frame.Opcode.IsMessage()) { - return TaskCache.CompletedTask; + return Task.CompletedTask; } LogFrame("Receiving", frame); diff --git a/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs b/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs index 88a88c07b7..001437eb3c 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs @@ -4,7 +4,6 @@ using System; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.Internal; namespace Microsoft.Extensions.WebSockets.Internal { @@ -135,7 +134,7 @@ namespace Microsoft.Extensions.WebSockets.Internal connection.ExecuteAsync((frame, _) => { messageHandler(frame); - return TaskCache.CompletedTask; + return Task.CompletedTask; }, null); /// @@ -149,7 +148,7 @@ namespace Microsoft.Extensions.WebSockets.Internal connection.ExecuteAsync((frame, s) => { messageHandler(frame, s); - return TaskCache.CompletedTask; + return Task.CompletedTask; }, state); /// diff --git a/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj b/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj index 6d15f40632..d886a6890c 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj +++ b/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj @@ -17,7 +17,6 @@ - diff --git a/test/Common/TaskExtensions.cs b/test/Common/TaskExtensions.cs index 51c7b549e4..10b31649bc 100644 --- a/test/Common/TaskExtensions.cs +++ b/test/Common/TaskExtensions.cs @@ -1,7 +1,8 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Runtime.CompilerServices; using System.Threading.Tasks; namespace Microsoft.AspNetCore.SignalR.Tests.Common @@ -10,36 +11,48 @@ namespace Microsoft.AspNetCore.SignalR.Tests.Common { private const int DefaultTimeout = 5000; - public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout) + public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) { - return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds)); + return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber); } - public static async Task OrTimeout(this Task task, TimeSpan timeout) + public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) { var completed = await Task.WhenAny(task, Task.Delay(timeout)); if (completed != task) { - throw new TimeoutException(); + throw new TimeoutException(GetMessage(memberName, filePath, lineNumber)); } await task; } - public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout) + public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) { - return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds)); + return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber); } - public static async Task OrTimeout(this Task task, TimeSpan timeout) + public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) { var completed = await Task.WhenAny(task, Task.Delay(timeout)); if (completed != task) { - throw new TimeoutException(); + throw new TimeoutException(GetMessage(memberName, filePath, lineNumber)); } return await task; } + + private static string GetMessage(string memberName, string filePath, int? lineNumber) + { + if (!string.IsNullOrEmpty(memberName)) + { + return $"Operation in {memberName} timed out at {filePath}:{lineNumber}"; + } + else + { + return "Operation timed out"; + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 51cf710546..24e5a3425f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -98,7 +98,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests using (var httpClient = _testServer.CreateClient()) { - var connection = new HubConnection(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), loggerFactory); + var connection = new HubConnection(new Uri("http://test/hubs"), loggerFactory); try { await connection.StartAsync(TransportType.LongPolling, httpClient); @@ -133,13 +133,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests tcs.TrySetResult((string)a[0]); }); - await connection.Invoke("CallEcho", originalMessage); + await connection.Invoke("CallEcho", originalMessage).OrTimeout(); Assert.Equal(originalMessage, await tcs.Task.OrTimeout()); } finally { - await connection.DisposeAsync(); + await connection.DisposeAsync().OrTimeout(); } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs new file mode 100644 index 0000000000..de53017a80 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs @@ -0,0 +1,156 @@ +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests +{ + // This includes tests that verify HubConnection conforms to the Hub Protocol, without setting up a full server (even TestServer). + // We can also have more control over the messages we send to HubConnection in order to ensure that protocol errors and other quirks + // don't cause problems. + public class HubConnectionProtocolTests + { + [Fact] + public async Task InvokeSendsAnInvocationMessage() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var invokeTask = hubConnection.Invoke("Foo", typeof(void)); + + var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); + + Assert.Equal("{\"invocationId\":\"1\",\"type\":1,\"target\":\"Foo\",\"arguments\":[]}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task InvokeCompletedWhenCompletionMessageReceived() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var invokeTask = hubConnection.Invoke("Foo", typeof(void)); + + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); + + await invokeTask.OrTimeout(); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task InvokeYieldsResultWhenCompletionMessageReceived() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var invokeTask = hubConnection.Invoke("Foo"); + + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = 42 }).OrTimeout(); + + Assert.Equal(42, await invokeTask.OrTimeout()); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task InvokeFailsWithExceptionWhenCompletionWithErrorReceived() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var invokeTask = hubConnection.Invoke("Foo"); + + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout(); + + var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); + Assert.Equal("An error occurred", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + // This will fail (intentionally) when we support streaming! + public async Task InvokeFailsWithErrorWhenStreamingItemReceived() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var invokeTask = hubConnection.Invoke("Foo"); + + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, result = 42 }).OrTimeout(); + + var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); + Assert.Equal("Streaming method results are not supported", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task HandlerRegisteredWithOnIsFiredWhenInvocationReceived() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + var handlerCalled = new TaskCompletionSource(); + try + { + await hubConnection.StartAsync(); + + hubConnection.On("Foo", new[] { typeof(int), typeof(string), typeof(float) }, (a) => handlerCalled.TrySetResult(a)); + + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 1, target = "Foo", arguments = new object[] { 1, "Foo", 2.0f } }).OrTimeout(); + + var results = await handlerCalled.Task.OrTimeout(); + Assert.Equal(3, results.Length); + Assert.Equal(1, results[0]); + Assert.Equal("Foo", results[1]); + Assert.Equal(2.0f, results[2]); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index 3298a26e03..0755c598e4 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -2,12 +2,13 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; -using System.IO; +using System.Buffers; using System.Net.Http; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Client.Tests; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; @@ -25,14 +26,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public void CannotCreateHubConnectionWithNullUrl() { var exception = Assert.Throws( - () => new HubConnection((Uri)null, Mock.Of(), Mock.Of())); + () => new HubConnection((Uri)null, Mock.Of())); Assert.Equal("url", exception.ParamName); } [Fact] public async Task CanDisposeNotStartedHubConnection() { - await new HubConnection(new Uri("http://fakeuri.org"), Mock.Of(), new LoggerFactory()) + await new HubConnection(new Uri("http://fakeuri.org"), new LoggerFactory()) .DisposeAsync(); } @@ -50,7 +51,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), Mock.Of(), new LoggerFactory()); + var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), new LoggerFactory()); try { @@ -81,7 +82,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), Mock.Of(), new LoggerFactory()); + var hubConnection = new HubConnection(new Uri("http://fakeuri.org/"), new LoggerFactory()); await hubConnection.StartAsync(TransportType.LongPolling, httpClient); await hubConnection.DisposeAsync(); @@ -110,7 +111,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var exception = await Assert.ThrowsAsync(async () => await hubConnection.Invoke("test")); - Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message); + Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message); } [Fact] @@ -119,12 +120,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var mockConnection = new Mock(); var exception = new InvalidOperationException(); - var mockInvocationAdapter = new Mock(); - mockInvocationAdapter - .Setup(a => a.WriteMessageAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(Task.FromException(exception)); - - var hubConnection = new HubConnection(mockConnection.Object, mockInvocationAdapter.Object, null); + var mockProtocol = MockHubProtocol.Throw(exception); + var hubConnection = new HubConnection(mockConnection.Object, mockProtocol, null); await hubConnection.StartAsync(TransportType.All, httpClient: null); var actualException = @@ -146,7 +143,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), Mock.Of(), new LoggerFactory()); + var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), new LoggerFactory()); try { var connectedEventRaisedTcs = new TaskCompletionSource(); @@ -177,7 +174,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), Mock.Of(), new LoggerFactory()); + var hubConnection = new HubConnection(new Uri("http://fakeuri.org"), new LoggerFactory()); var closedEventTcs = new TaskCompletionSource(); hubConnection.Closed += e => closedEventTcs.SetResult(e); @@ -197,7 +194,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests .Callback(() => mockConnection.Raise(c => c.Closed += null, (Exception)null)) .Returns(Task.FromResult(null)); - var hubConnection = new HubConnection(mockConnection.Object, Mock.Of(), new LoggerFactory()); + var hubConnection = new HubConnection(mockConnection.Object, new LoggerFactory()); await hubConnection.StartAsync(new TestTransportFactory(Mock.Of()), httpClient: null); await hubConnection.DisposeAsync(); @@ -216,7 +213,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests .Callback(() => mockConnection.Raise(c => c.Closed += null, (Exception)null)) .Returns(Task.FromResult(null)); - var hubConnection = new HubConnection(mockConnection.Object, Mock.Of(), new LoggerFactory()); + var hubConnection = new HubConnection(mockConnection.Object, new LoggerFactory()); await hubConnection.StartAsync(new TestTransportFactory(Mock.Of()), httpClient: null); var invokeTask = hubConnection.Invoke("testMethod", typeof(int)); @@ -235,7 +232,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests .Callback(() => mockConnection.Raise(c => c.Closed += null, exception)) .Returns(Task.FromResult(null)); - var hubConnection = new HubConnection(mockConnection.Object, Mock.Of(), new LoggerFactory()); + var hubConnection = new HubConnection(mockConnection.Object, new LoggerFactory()); await hubConnection.StartAsync(new TestTransportFactory(Mock.Of()), httpClient: null); var invokeTask = hubConnection.Invoke("testMethod", typeof(int)); @@ -250,23 +247,69 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var mockConnection = new Mock(); - var invocationDescriptor = new InvocationDescriptor - { - Method = "NonExistingMethod123", - Arguments = new object[] { true, "arg2", 123 }, - Id = Guid.NewGuid().ToString() - }; + var invocation = new InvocationMessage(Guid.NewGuid().ToString(), nonBlocking: true, target: "NonExistingMethod123", arguments: new object[] { true, "arg2", 123 }); - var mockInvocationAdapter = new Mock(); - mockInvocationAdapter - .Setup(a => a.ReadMessageAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(Task.FromResult((InvocationMessage)invocationDescriptor)); + var mockProtocol = MockHubProtocol.ReturnOnParse(invocation); - var hubConnection = new HubConnection(mockConnection.Object, mockInvocationAdapter.Object, null); + var hubConnection = new HubConnection(mockConnection.Object, mockProtocol, null); await hubConnection.StartAsync(new TestTransportFactory(Mock.Of()), httpClient: null); mockConnection.Raise(c => c.Received += null, new object[] { new byte[] { }, MessageType.Text }); - mockInvocationAdapter.Verify(a => a.ReadMessageAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Once()); + Assert.Equal(1, mockProtocol.ParseCalls); + } + + // Moq really doesn't handle out parameters well, so to make these tests work I added a manual mock -anurse + private class MockHubProtocol : IHubProtocol + { + private HubMessage _parsed; + private Exception _error; + + public int ParseCalls { get; private set; } = 0; + public int WriteCalls { get; private set; } = 0; + + public MessageType MessageType => MessageType.Text; + + public static MockHubProtocol ReturnOnParse(HubMessage parsed) + { + return new MockHubProtocol + { + _parsed = parsed + }; + } + + public static MockHubProtocol Throw(Exception error) + { + return new MockHubProtocol + { + _error = error + }; + } + + public HubMessage ParseMessage(ReadOnlySpan input, IInvocationBinder binder) + { + ParseCalls += 1; + if (_error != null) + { + throw _error; + } + if (_parsed != null) + { + return _parsed; + } + + throw new InvalidOperationException("No Parsed Message provided"); + } + + public bool TryWriteMessage(HubMessage message, IOutput output) + { + WriteCalls += 1; + + if (_error != null) + { + throw _error; + } + return true; + } } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs new file mode 100644 index 0000000000..8a5eb49675 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -0,0 +1,116 @@ +using System; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.Sockets; +using Microsoft.AspNetCore.Sockets.Client; +using Newtonsoft.Json; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests +{ + internal class TestConnection : IConnection + { + private TaskCompletionSource _started = new TaskCompletionSource(); + private TaskCompletionSource _disposed = new TaskCompletionSource(); + + private Channel _sentMessages = Channel.CreateUnbounded(); + private Channel _receivedMessages = Channel.CreateUnbounded(); + + private CancellationTokenSource _receiveShutdownToken = new CancellationTokenSource(); + private Task _receiveLoop; + + public event Action Connected; + public event Action Received; + public event Action Closed; + + public Task Started => _started.Task; + public Task Disposed => _disposed.Task; + public ReadableChannel SentMessages => _sentMessages.In; + public WritableChannel ReceivedMessages => _receivedMessages.Out; + + public TestConnection() + { + _receiveLoop = ReceiveLoopAsync(_receiveShutdownToken.Token); + } + + public Task DisposeAsync() + { + _disposed.TrySetResult(null); + _receiveShutdownToken.Cancel(); + return _receiveLoop; + } + + public async Task SendAsync(byte[] data, MessageType type, CancellationToken cancellationToken) + { + if(!_started.Task.IsCompleted) + { + throw new InvalidOperationException("Connection must be started before SendAsync can be called"); + } + + var message = new Message(data, type, endOfMessage: true); + while (await _sentMessages.Out.WaitToWriteAsync(cancellationToken)) + { + if (_sentMessages.Out.TryWrite(message)) + { + return; + } + } + throw new ObjectDisposedException("Unable to send message, underlying channel was closed"); + } + + public Task StartAsync(ITransportFactory transportFactory, HttpClient httpClient) + { + _started.TrySetResult(null); + Connected?.Invoke(); + return Task.CompletedTask; + } + + public async Task ReadSentTextMessageAsync() + { + var message = await SentMessages.ReadAsync(); + if (message.Type != MessageType.Text) + { + throw new InvalidOperationException($"Unexpected message of type: {message.Type}"); + } + return Encoding.UTF8.GetString(message.Payload); + } + + public Task ReceiveJsonMessage(object jsonObject) + { + var json = JsonConvert.SerializeObject(jsonObject, Formatting.None); + var bytes = Encoding.UTF8.GetBytes(json); + var message = new Message(bytes, MessageType.Text); + + return _receivedMessages.Out.WriteAsync(message); + } + + private async Task ReceiveLoopAsync(CancellationToken token) + { + try + { + while (!token.IsCancellationRequested) + { + while (await _receivedMessages.In.WaitToReadAsync(token)) + { + while (_receivedMessages.In.TryRead(out var message)) + { + Received?.Invoke(message.Payload, message.Type); + } + } + } + Closed?.Invoke(null); + } + catch (OperationCanceledException) + { + // Do nothing, we were just asked to shut down. + Closed?.Invoke(null); + } + catch (Exception ex) + { + Closed?.Invoke(ex); + } + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs new file mode 100644 index 0000000000..6dc6fb4a0d --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs @@ -0,0 +1,257 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Newtonsoft.Json; +using Newtonsoft.Json.Serialization; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol +{ + public class JsonHubProtocolTests + { + public static IEnumerable ProtocolTestData => new[] + { + new object[] { new InvocationMessage("123", true, "Target", 1, "Foo", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"nonBlocking\":true,\"arguments\":[1,\"Foo\",2.0]}" }, + new object[] { new InvocationMessage("123", false, "Target", 1, "Foo", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[1,\"Foo\",2.0]}" }, + new object[] { new InvocationMessage("123", false, "Target", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[true]}" }, + new object[] { new InvocationMessage("123", false, "Target", new object[] { null }), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[null]}" }, + new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}]}" }, + new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}]}" }, + new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}]}" }, + new object[] { new InvocationMessage("123", false, "Target", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}]}" }, + + new object[] { new StreamItemMessage("123", 1), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":1}" }, + new object[] { new StreamItemMessage("123", "Foo"), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":\"Foo\"}" }, + new object[] { new StreamItemMessage("123", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":2.0}" }, + new object[] { new StreamItemMessage("123", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":true}" }, + new object[] { new StreamItemMessage("123", null), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":null}" }, + new object[] { new StreamItemMessage("123", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}}" }, + new object[] { new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}}" }, + new object[] { new StreamItemMessage("123", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}}" }, + new object[] { new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":2,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}}" }, + + new object[] { CompletionMessage.WithResult("123", 1), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":1}" }, + new object[] { CompletionMessage.WithResult("123", "Foo"), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":\"Foo\"}" }, + new object[] { CompletionMessage.WithResult("123", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":2.0}" }, + new object[] { CompletionMessage.WithResult("123", true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":true}" }, + new object[] { CompletionMessage.WithResult("123", null), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":null}" }, + new object[] { CompletionMessage.WithError("123", "Whoops!"), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"error\":\"Whoops!\"}" }, + new object[] { CompletionMessage.WithResult("123", new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\"}}" }, + new object[] { CompletionMessage.WithResult("123", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\"}}" }, + new object[] { CompletionMessage.WithResult("123", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null}}" }, + new object[] { CompletionMessage.WithResult("123", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null}}" }, + }; + + [Theory] + [MemberData(nameof(ProtocolTestData))] + public async Task WriteMessage(HubMessage message, bool camelCase, NullValueHandling nullValueHandling, string expectedOutput) + { + var jsonSerializer = new JsonSerializer + { + NullValueHandling = nullValueHandling, + ContractResolver = camelCase ? new CamelCasePropertyNamesContractResolver() : new DefaultContractResolver() + }; + + var protocol = new JsonHubProtocol(jsonSerializer); + var encoded = await protocol.WriteToArrayAsync(message); + var json = Encoding.UTF8.GetString(encoded); + + Assert.Equal(expectedOutput, json); + } + + [Theory] + [MemberData(nameof(ProtocolTestData))] + public void ParseMessage(HubMessage expectedMessage, bool camelCase, NullValueHandling nullValueHandling, string input) + { + var jsonSerializer = new JsonSerializer + { + NullValueHandling = nullValueHandling, + ContractResolver = camelCase ? new CamelCasePropertyNamesContractResolver() : new DefaultContractResolver() + }; + + var binder = new TestBinder(expectedMessage); + var protocol = new JsonHubProtocol(jsonSerializer); + var message = protocol.ParseMessage(Encoding.UTF8.GetBytes(input), binder); + + Assert.Equal(expectedMessage, message, TestEqualityComparer.Instance); + } + + [Theory] + [InlineData("", "Error reading JSON.")] + [InlineData("null", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")] + [InlineData("42", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")] + [InlineData("'foo'", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")] + [InlineData("[42]", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")] + [InlineData("{}", "Missing required property 'type'.")] + + [InlineData("{'type':1}", "Missing required property 'invocationId'.")] + [InlineData("{'type':1,'invocationId':42}", "Expected 'invocationId' to be of type String.")] + [InlineData("{'type':1,'invocationId':'42','target':42}", "Expected 'target' to be of type String.")] + [InlineData("{'type':1,'invocationId':'42','target':'foo'}", "Missing required property 'arguments'.")] + [InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':{}}", "Expected 'arguments' to be of type Array.")] + + [InlineData("{'type':2}", "Missing required property 'invocationId'.")] + [InlineData("{'type':2,'invocationId':42}", "Expected 'invocationId' to be of type String.")] + [InlineData("{'type':2,'invocationId':'42'}", "Missing required property 'result'.")] + + [InlineData("{'type':3}", "Missing required property 'invocationId'.")] + [InlineData("{'type':3,'invocationId':42}", "Expected 'invocationId' to be of type String.")] + [InlineData("{'type':3,'invocationId':'42','error':[]}", "Expected 'error' to be of type String.")] + + [InlineData("{'type':4}", "Unknown message type: 4")] + [InlineData("{'type':'foo'}", "Expected 'type' to be of type Integer.")] + public void InvalidMessages(string input, string expectedMessage) + { + var binder = new TestBinder(); + var protocol = new JsonHubProtocol(new JsonSerializer()); + var ex = Assert.Throws(() => protocol.ParseMessage(Encoding.UTF8.GetBytes(input), binder)); + Assert.Equal(expectedMessage, ex.Message); + } + + [Theory] + [InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':[]}", "Invocation provides 0 argument(s) but target expects 2.")] + [InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':[42, 'foo'],'nonBlocking':42}", "Expected 'nonBlocking' to be of type Boolean.")] + [InlineData("{'type':3,'invocationId':'42','error':'foo','result':true}", "The 'error' and 'result' properties are mutually exclusive.")] + public void InvalidMessagesWithBinder(string input, string expectedMessage) + { + var binder = new TestBinder(paramTypes: new[] { typeof(int), typeof(string) }, returnType: typeof(bool)); + var protocol = new JsonHubProtocol(new JsonSerializer()); + var ex = Assert.Throws(() => protocol.ParseMessage(Encoding.UTF8.GetBytes(input), binder)); + Assert.Equal(expectedMessage, ex.Message); + } + + private class CustomObject : IEquatable + { + // Not intended to be a full set of things, just a smattering of sample serializations + public string StringProp => "SignalR!"; + + public double DoubleProp => 6.2831853071; + + public int IntProp => 42; + + public DateTime DateTimeProp => new DateTime(2017, 4, 11); + + public object NullProp => null; + + public override bool Equals(object obj) + { + return obj is CustomObject o && Equals(o); + } + + public override int GetHashCode() + { + // This is never used in a hash table + return 0; + } + + public bool Equals(CustomObject right) + { + // This allows the comparer below to properly compare the object in the test. + return string.Equals(StringProp, right.StringProp, StringComparison.Ordinal) && + DoubleProp == right.DoubleProp && + IntProp == right.IntProp && + DateTime.Equals(DateTimeProp, right.DateTimeProp) && + NullProp == right.NullProp; + } + } + + // Binder that works based on the expected message argument/result types :) + private class TestBinder : IInvocationBinder + { + private readonly Type[] _paramTypes; + private readonly Type _returnType; + + public TestBinder(HubMessage expectedMessage) + { + switch(expectedMessage) + { + case InvocationMessage i: + _paramTypes = i.Arguments?.Select(a => a?.GetType() ?? typeof(object))?.ToArray(); + break; + case StreamItemMessage s: + _returnType = s.Item?.GetType() ?? typeof(object); + break; + case CompletionMessage c: + _returnType = c.Result?.GetType() ?? typeof(object); + break; + } + } + + public TestBinder() : this(null, null) { } + public TestBinder(Type[] paramTypes) : this(paramTypes, null) { } + public TestBinder(Type returnType) : this(null, returnType) {} + public TestBinder(Type[] paramTypes, Type returnType) + { + _paramTypes = paramTypes; + _returnType = returnType; + } + + public Type[] GetParameterTypes(string methodName) + { + if (_paramTypes != null) + { + return _paramTypes; + } + throw new InvalidOperationException("Unexpected binder call"); + } + + public Type GetReturnType(string invocationId) + { + if (_returnType != null) + { + return _returnType; + } + throw new InvalidOperationException("Unexpected binder call"); + } + } + + private class TestEqualityComparer : IEqualityComparer + { + public static readonly TestEqualityComparer Instance = new TestEqualityComparer(); + + private TestEqualityComparer() { } + + public bool Equals(HubMessage x, HubMessage y) + { + if (!string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal)) + { + return false; + } + + return InvocationMessagesEqual(x, y) || StreamItemMessagesEqual(x, y) || CompletionMessagesEqual(x, y); + } + + private bool CompletionMessagesEqual(HubMessage x, HubMessage y) + { + return x is CompletionMessage left && y is CompletionMessage right && + string.Equals(left.Error, right.Error, StringComparison.Ordinal) && + Equals(left.Result, right.Result) && + left.HasResult == right.HasResult; + } + + private bool StreamItemMessagesEqual(HubMessage x, HubMessage y) + { + return x is StreamItemMessage left && y is StreamItemMessage right && + Equals(left.Item, right.Item); + } + + private bool InvocationMessagesEqual(HubMessage x, HubMessage y) + { + return x is InvocationMessage left && y is InvocationMessage right && + string.Equals(left.Target, right.Target, StringComparison.Ordinal) && + Enumerable.SequenceEqual(left.Arguments, right.Arguments) && + left.NonBlocking == right.NonBlocking; + } + + public int GetHashCode(HubMessage obj) + { + // We never use these in a hash-table + return 0; + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj new file mode 100644 index 0000000000..d06c98408d --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj @@ -0,0 +1,31 @@ + + + + + + netcoreapp2.0;net46 + netcoreapp2.0 + + true + true + + + + + + + + + + + + + + + + + + + + + diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 84af03faf6..91b51ba380 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -1,14 +1,14 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Internal; using Moq; using Xunit; -using Microsoft.AspNetCore.SignalR.Tests.Common; namespace Microsoft.AspNetCore.SignalR.Tests { @@ -19,7 +19,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var trackDispose = new TrackDispose(); var serviceProvider = CreateServiceProvider(s => s.AddSingleton(trackDispose)); - var endPoint = serviceProvider.GetService>(); + var endPoint = serviceProvider.GetService>(); using (var client = new TestClient(serviceProvider)) { @@ -127,10 +127,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(client.Connection); - var result = await client.Invoke(nameof(MethodHub.TaskValueMethod)).OrTimeout(); + var result = (await client.InvokeAsync(nameof(MethodHub.TaskValueMethod)).OrTimeout()).Result; // json serializer makes this a long - Assert.Equal(42L, result.Result); + Assert.Equal(42L, result); // kill the connection client.Dispose(); @@ -150,10 +150,33 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(client.Connection); - var result = await client.Invoke("echo", "hello").OrTimeout(); + var result = (await client.InvokeAsync("echo", "hello").OrTimeout()).Result; - Assert.Null(result.Error); - Assert.Equal("hello", result.Result); + Assert.Equal("hello", result); + + // kill the connection + client.Dispose(); + + await endPointTask.OrTimeout(); + } + } + + [Theory] + [InlineData(nameof(MethodHub.MethodThatThrows))] + [InlineData(nameof(MethodHub.MethodThatYieldsFailedTask))] + public async Task HubMethodCanThrowOrYieldFailedTask(string methodName) + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient(serviceProvider)) + { + var endPointTask = endPoint.OnConnectedAsync(client.Connection); + + var result = (await client.InvokeAsync(methodName).OrTimeout()); + + Assert.Equal("BOOM!", result.Error); // kill the connection client.Dispose(); @@ -173,10 +196,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(client.Connection); - var result = await client.Invoke(nameof(MethodHub.ValueMethod)).OrTimeout(); + var result = (await client.InvokeAsync(nameof(MethodHub.ValueMethod)).OrTimeout()).Result; // json serializer makes this a long - Assert.Equal(43L, result.Result); + Assert.Equal(43L, result); // kill the connection client.Dispose(); @@ -196,9 +219,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(client.Connection); - var result = await client.Invoke(nameof(MethodHub.VoidMethod)).OrTimeout(); + var result = (await client.InvokeAsync(nameof(MethodHub.VoidMethod)).OrTimeout()).Result; - Assert.Null(result.Result); + Assert.Null(result); // kill the connection client.Dispose(); @@ -218,9 +241,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(client.Connection); - var result = await client.Invoke(nameof(MethodHub.ConcatString), (byte)32, 42, 'm', "string").OrTimeout(); + var result = (await client.InvokeAsync(nameof(MethodHub.ConcatString), (byte)32, 42, 'm', "string").OrTimeout()).Result; - Assert.Equal("32, 42, m, string", result.Result); + Assert.Equal("32, 42, m, string", result); // kill the connection client.Dispose(); @@ -240,9 +263,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(client.Connection); - var result = await client.Invoke(nameof(InheritedHub.BaseMethod), "string").OrTimeout(); + var result = (await client.InvokeAsync(nameof(InheritedHub.BaseMethod), "string").OrTimeout()).Result; - Assert.Equal("string", result.Result); + Assert.Equal("string", result); // kill the connection client.Dispose(); @@ -262,9 +285,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(client.Connection); - var result = await client.Invoke(nameof(InheritedHub.VirtualMethod), 10).OrTimeout(); + var result = (await client.InvokeAsync(nameof(InheritedHub.VirtualMethod), 10).OrTimeout()).Result; - Assert.Equal(0L, result.Result); + Assert.Equal(0L, result); // kill the connection client.Dispose(); @@ -284,7 +307,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(client.Connection); - var result = await client.Invoke(nameof(MethodHub.OnDisconnectedAsync)).OrTimeout(); + var result = await client.InvokeAsync(nameof(MethodHub.OnDisconnectedAsync)).OrTimeout(); Assert.Equal("Unknown hub method 'OnDisconnectedAsync'", result.Error); @@ -326,15 +349,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout(); - await firstClient.Invoke(nameof(MethodHub.BroadcastMethod), "test").OrTimeout(); + await firstClient.SendInvocationAsync(nameof(MethodHub.BroadcastMethod), "test").OrTimeout(); foreach (var result in await Task.WhenAll( - firstClient.Read(), - secondClient.Read()).OrTimeout()) + firstClient.Read(), + secondClient.Read()).OrTimeout()) { - Assert.Equal("Broadcast", result.Method); - Assert.Equal(1, result.Arguments.Length); - Assert.Equal("test", result.Arguments[0]); + var invocation = Assert.IsType(result); + Assert.Equal("Broadcast", invocation.Target); + Assert.Equal(1, invocation.Arguments.Length); + Assert.Equal("test", invocation.Arguments[0]); } // kill the connections @@ -360,23 +384,24 @@ namespace Microsoft.AspNetCore.SignalR.Tests await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout(); - var result = await firstClient.Invoke(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout(); + var result = (await firstClient.InvokeAsync(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout()).Result; + // check that 'firstConnection' hasn't received the group send - Assert.Null(result.Id); + Assert.Null(firstClient.TryRead()); // check that 'secondConnection' hasn't received the group send - Assert.Null(await secondClient.TryRead().OrTimeout()); + Assert.Null(secondClient.TryRead()); - result = await secondClient.Invoke(nameof(MethodHub.GroupAddMethod), "testGroup").OrTimeout(); - Assert.Null(result.Id); + result = (await secondClient.InvokeAsync(nameof(MethodHub.GroupAddMethod), "testGroup").OrTimeout()).Result; - await firstClient.Invoke(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout(); + await firstClient.SendInvocationAsync(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout(); // check that 'secondConnection' has received the group send - var descriptor = await secondClient.Read().OrTimeout(); - Assert.Equal("Send", descriptor.Method); - Assert.Equal(1, descriptor.Arguments.Length); - Assert.Equal("test", descriptor.Arguments[0]); + var hubMessage = await secondClient.Read().OrTimeout(); + var invocation = Assert.IsType(hubMessage); + Assert.Equal("Send", invocation.Target); + Assert.Equal(1, invocation.Arguments.Length); + Assert.Equal("test", invocation.Arguments[0]); // kill the connections firstClient.Dispose(); @@ -397,7 +422,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(client.Connection); - await client.Invoke(nameof(MethodHub.GroupRemoveMethod), "testGroup").OrTimeout(); + await client.SendInvocationAsync(nameof(MethodHub.GroupRemoveMethod), "testGroup").OrTimeout(); // kill the connection client.Dispose(); @@ -421,13 +446,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout(); - await firstClient.Invoke(nameof(MethodHub.ClientSendMethod), secondClient.Connection.User.Identity.Name, "test").OrTimeout(); + await firstClient.SendInvocationAsync(nameof(MethodHub.ClientSendMethod), secondClient.Connection.User.Identity.Name, "test").OrTimeout(); // check that 'secondConnection' has received the group send - var result = await secondClient.Read().OrTimeout(); - Assert.Equal("Send", result.Method); - Assert.Equal(1, result.Arguments.Length); - Assert.Equal("test", result.Arguments[0]); + var hubMessage = await secondClient.Read().OrTimeout(); + var invocation = Assert.IsType(hubMessage); + Assert.Equal("Send", invocation.Target); + Assert.Equal(1, invocation.Arguments.Length); + Assert.Equal("test", invocation.Arguments[0]); // kill the connections firstClient.Dispose(); @@ -452,13 +478,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout(); - await firstClient.Invoke(nameof(MethodHub.ConnectionSendMethod), secondClient.Connection.ConnectionId, "test").OrTimeout(); + await firstClient.SendInvocationAsync(nameof(MethodHub.ConnectionSendMethod), secondClient.Connection.ConnectionId, "test").OrTimeout(); // check that 'secondConnection' has received the group send - var result = await secondClient.Read().OrTimeout(); - Assert.Equal("Send", result.Method); - Assert.Equal(1, result.Arguments.Length); - Assert.Equal("test", result.Arguments[0]); + var hubMessage = await secondClient.Read().OrTimeout(); + var invocation = Assert.IsType(hubMessage); + Assert.Equal("Send", invocation.Target); + Assert.Equal(1, invocation.Arguments.Length); + Assert.Equal("test", invocation.Arguments[0]); // kill the connections firstClient.Dispose(); @@ -575,7 +602,17 @@ namespace Microsoft.AspNetCore.SignalR.Tests public override Task OnDisconnectedAsync(Exception e) { - return TaskCache.CompletedTask; + return Task.CompletedTask; + } + + public void MethodThatThrows() + { + throw new InvalidOperationException("BOOM!"); + } + + public Task MethodThatYieldsFailedTask() + { + return Task.FromException(new InvalidOperationException("BOOM!")); } } @@ -611,11 +648,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } - private class TestHub : Hub + private class DisposeTrackingHub : Hub { private TrackDispose _trackDispose; - public TestHub(TrackDispose trackDispose) + public DisposeTrackingHub(TrackDispose trackDispose) { _trackDispose = trackDispose; } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index 85b9b3ec36..1139ff5295 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -2,29 +2,30 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO; +using System.Collections.Generic; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Internal; -using Microsoft.Extensions.DependencyInjection; +using Newtonsoft.Json; namespace Microsoft.AspNetCore.SignalR.Tests { - public class TestClient : IDisposable + public class TestClient : IDisposable, IInvocationBinder { private static int _id; - private IInvocationAdapter _adapter; + private IHubProtocol _protocol; private CancellationTokenSource _cts; - private TestBinder _binder; public Connection Connection; public IChannelConnection Application { get; } public Task Connected => Connection.Metadata.Get>("ConnectedTask").Task; - public TestClient(IServiceProvider serviceProvider, string format = "json") + public TestClient(IServiceProvider serviceProvider) { var transportToApplication = Channel.CreateUnbounded(); var applicationToTransport = Channel.CreateUnbounded(); @@ -33,62 +34,80 @@ namespace Microsoft.AspNetCore.SignalR.Tests var transport = ChannelConnection.Create(input: transportToApplication, output: applicationToTransport); Connection = new Connection(Guid.NewGuid().ToString(), transport); - Connection.Metadata["formatType"] = format; Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) })); Connection.Metadata["ConnectedTask"] = new TaskCompletionSource(); - var invocationAdapter = serviceProvider.GetService(); - _adapter = invocationAdapter.GetInvocationAdapter(format); - - _binder = new TestBinder(); + _protocol = new JsonHubProtocol(new JsonSerializer()); _cts = new CancellationTokenSource(); } - public async Task Invoke(string methodName, params object[] args) where T : InvocationMessage + public async Task InvokeAsync(string methodName, params object[] args) { - await Invoke(methodName, args); + var invocationId = await SendInvocationAsync(methodName, args); - return await Read(); - } - - public async Task Invoke(string methodName, params object[] args) - { - var stream = new MemoryStream(); - await _adapter.WriteMessageAsync(new InvocationDescriptor + while (true) { - Arguments = args, - Method = methodName - }, - stream); + var message = await Read(); - await Application.Output.WriteAsync(new Message(stream.ToArray(), MessageType.Binary, endOfMessage: true)); - } - - public async Task Read() where T : InvocationMessage - { - while (await Application.Input.WaitToReadAsync(_cts.Token)) - { - var value = await TryRead(); - - if (value != null) + if (!string.Equals(message.InvocationId, invocationId)) { - return value; + throw new NotSupportedException("TestClient does not support multiple outgoing invocations!"); + } + + if (message == null) + { + throw new InvalidOperationException("Connection aborted!"); + } + + switch (message) + { + case StreamItemMessage result: + throw new NotSupportedException("TestClient does not support streaming!"); + case CompletionMessage completion: + return completion; + default: + throw new NotSupportedException("TestClient does not support receiving invocations!"); } } - - return null; } - public async Task TryRead() where T : InvocationMessage + public async Task SendInvocationAsync(string methodName, params object[] args) { - Message message; - if (Application.Input.TryRead(out message)) - { - var value = await _adapter.ReadMessageAsync(new MemoryStream(message.Payload), _binder); - return value as T; - } + var invocationId = GetInvocationId(); + var payload = await _protocol.WriteToArrayAsync(new InvocationMessage(invocationId, nonBlocking: false, target: methodName, arguments: args)); + await Application.Output.WriteAsync(new Message(payload, _protocol.MessageType, endOfMessage: true)); + + return invocationId; + } + + public async Task Read() + { + while (true) + { + var message = TryRead(); + + if (message == null) + { + if (!await Application.Input.WaitToReadAsync()) + { + return null; + } + } + else + { + return message; + } + } + } + + public HubMessage TryRead() + { + if (Application.Input.TryRead(out var message)) + { + return _protocol.ParseMessage(message.Payload, this); + } return null; } @@ -98,18 +117,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests Connection.Dispose(); } - private class TestBinder : IInvocationBinder + private static string GetInvocationId() { - public Type[] GetParameterTypes(string methodName) - { - // TODO: Possibly support actual client methods - return new[] { typeof(object) }; - } + return Guid.NewGuid().ToString("N"); + } - public Type GetReturnType(string invocationId) - { - return typeof(object); - } + Type[] IInvocationBinder.GetParameterTypes(string methodName) + { + // TODO: Possibly support actual client methods + return new[] { typeof(object) }; + } + + Type IInvocationBinder.GetReturnType(string invocationId) + { + return typeof(object); } } }