diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs index c555343736..6d6624aef3 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -77,6 +78,12 @@ namespace Microsoft.AspNetCore.SignalR.Client public async Task StartAsync() { await _connection.StartAsync(); + + using (var memoryStream = new MemoryStream()) + { + NegotiationProtocol.TryWriteProtocolNegotiationMessage(new NegotiationMessage(_protocol.Name), memoryStream); + await _connection.SendAsync(memoryStream.ToArray(), _connectionActive.Token); + } } public async Task DisposeAsync() diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs index c5b8bae125..9b15d50d6d 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs @@ -1,8 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; - namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { public abstract class HubMessage diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs index 4ba7423321..bd571a20d7 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Buffers; using System.Collections.Generic; using System.IO; @@ -10,6 +9,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { public interface IHubProtocol { + string Name { get; } + bool TryParseMessages(ReadOnlySpan input, IInvocationBinder binder, out IList messages); bool TryWriteMessage(HubMessage message, Stream output); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs index 0a8a2de9df..f5ac835814 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs @@ -44,6 +44,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol _payloadSerializer = payloadSerializer; } + public string Name { get => "json"; } + public bool TryParseMessages(ReadOnlySpan input, IInvocationBinder binder, out IList messages) { messages = new List(); @@ -80,20 +82,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { // PERF: Could probably use the JsonTextReader directly for better perf and fewer allocations var token = JToken.ReadFrom(reader); - if (token == null) - { - return null; - } - if (token.Type != JTokenType.Object) + if (token == null || token.Type != JTokenType.Object) { - throw new FormatException($"Unexpected JSON Token Type '{token.Type}'. Expected a JSON Object."); + throw new FormatException($"Unexpected JSON Token Type '{token?.Type}'. Expected a JSON Object."); } var json = (JObject)token; // Determine the type of the message - var type = GetRequiredProperty(json, TypePropertyName, JTokenType.Integer); + var type = JsonUtils.GetRequiredProperty(json, TypePropertyName, JTokenType.Integer); switch (type) { case InvocationMessageType: @@ -194,11 +192,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private InvocationMessage BindInvocationMessage(JObject json, IInvocationBinder binder) { - var invocationId = GetRequiredProperty(json, InvocationIdPropertyName, JTokenType.String); - var target = GetRequiredProperty(json, TargetPropertyName, JTokenType.String); - var nonBlocking = GetOptionalProperty(json, NonBlockingPropertyName, JTokenType.Boolean); + var invocationId = JsonUtils.GetRequiredProperty(json, InvocationIdPropertyName, JTokenType.String); + var target = JsonUtils.GetRequiredProperty(json, TargetPropertyName, JTokenType.String); + var nonBlocking = JsonUtils.GetOptionalProperty(json, NonBlockingPropertyName, JTokenType.Boolean); - var args = GetRequiredProperty(json, ArgumentsPropertyName, JTokenType.Array); + var args = JsonUtils.GetRequiredProperty(json, ArgumentsPropertyName, JTokenType.Array); var paramTypes = binder.GetParameterTypes(target); var arguments = new object[args.Count]; @@ -221,8 +219,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private StreamItemMessage BindResultMessage(JObject json, IInvocationBinder binder) { - var invocationId = GetRequiredProperty(json, InvocationIdPropertyName, JTokenType.String); - var result = GetRequiredProperty(json, ItemPropertyName); + var invocationId = JsonUtils.GetRequiredProperty(json, InvocationIdPropertyName, JTokenType.String); + var result = JsonUtils.GetRequiredProperty(json, ItemPropertyName); var returnType = binder.GetReturnType(invocationId); return new StreamItemMessage(invocationId, result?.ToObject(returnType, _payloadSerializer)); @@ -230,8 +228,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private CompletionMessage BindCompletionMessage(JObject json, IInvocationBinder binder) { - var invocationId = GetRequiredProperty(json, InvocationIdPropertyName, JTokenType.String); - var error = GetOptionalProperty(json, ErrorPropertyName, JTokenType.String); + var invocationId = JsonUtils.GetRequiredProperty(json, InvocationIdPropertyName, JTokenType.String); + var error = JsonUtils.GetOptionalProperty(json, ErrorPropertyName, JTokenType.String); var resultProp = json.Property(ResultPropertyName); if (error != null && resultProp != null) @@ -250,38 +248,5 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol return new CompletionMessage(invocationId, error, result: payload, hasResult: true); } } - - private T GetOptionalProperty(JObject json, string property, JTokenType expectedType = JTokenType.None, T defaultValue = default(T)) - { - var prop = json[property]; - - if (prop == null) - { - return defaultValue; - } - - return GetValue(property, expectedType, prop); - } - - private T GetRequiredProperty(JObject json, string property, JTokenType expectedType = JTokenType.None) - { - var prop = json[property]; - - if (prop == null) - { - throw new FormatException($"Missing required property '{property}'."); - } - - return GetValue(property, expectedType, prop); - } - - private static T GetValue(string property, JTokenType expectedType, JToken prop) - { - if (expectedType != JTokenType.None && prop.Type != expectedType) - { - throw new FormatException($"Expected '{property}' to be of type {expectedType}."); - } - return prop.Value(); - } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonUtils.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonUtils.cs new file mode 100644 index 0000000000..4e574f7f90 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonUtils.cs @@ -0,0 +1,44 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Newtonsoft.Json.Linq; + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public static class JsonUtils + { + public static T GetOptionalProperty(JObject json, string property, JTokenType expectedType = JTokenType.None, T defaultValue = default(T)) + { + var prop = json[property]; + + if (prop == null) + { + return defaultValue; + } + + return GetValue(property, expectedType, prop); + } + + public static T GetRequiredProperty(JObject json, string property, JTokenType expectedType = JTokenType.None) + { + var prop = json[property]; + + if (prop == null) + { + throw new FormatException($"Missing required property '{property}'."); + } + + return GetValue(property, expectedType, prop); + } + + public static T GetValue(string property, JTokenType expectedType, JToken prop) + { + if (expectedType != JTokenType.None && prop.Type != expectedType) + { + throw new FormatException($"Expected '{property}' to be of type {expectedType}."); + } + return prop.Value(); + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs index 8ba9abb8bc..0fc4c43228 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs @@ -16,6 +16,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private const int StreamItemMessageType = 2; private const int CompletionMessageType = 3; + public string Name { get => "messagepack"; } + public bool TryParseMessages(ReadOnlySpan input, IInvocationBinder binder, out IList messages) { messages = new List(); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationMessage.cs new file mode 100644 index 0000000000..c3e21800c2 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationMessage.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public class NegotiationMessage + { + public NegotiationMessage(string protocol) + { + Protocol = protocol; + } + + public string Protocol { get; } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationProtocol.cs new file mode 100644 index 0000000000..574e5bab4e --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationProtocol.cs @@ -0,0 +1,59 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using Microsoft.AspNetCore.Sockets.Internal.Formatters; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public static class NegotiationProtocol + { + private const string ProtocolPropertyName = "protocol"; + + public static bool TryWriteProtocolNegotiationMessage(NegotiationMessage negotiationMessage, Stream output) + { + using (var memoryStream = new MemoryStream()) + { + using (var writer = new JsonTextWriter(new StreamWriter(memoryStream))) + { + writer.WriteStartObject(); + writer.WritePropertyName(ProtocolPropertyName); + writer.WriteValue(negotiationMessage.Protocol); + writer.WriteEndObject(); + } + + memoryStream.Flush(); + return TextMessageFormatter.TryWriteMessage(new ReadOnlySpan(memoryStream.ToArray()), output); + } + } + + public static bool TryReadProtocolNegotiationMessage(ReadOnlySpan input, out NegotiationMessage negotiationMessage) + { + var parser = new TextMessageParser(); + if (!parser.TryParseMessage(ref input, out var payload)) + { + throw new FormatException("Unable to parse payload as a negotiation message."); + } + + using (var memoryStream = new MemoryStream(payload.ToArray())) + { + using (var reader = new JsonTextReader(new StreamReader(memoryStream))) + { + var token = JToken.ReadFrom(reader); + if (token == null || token.Type != JTokenType.Object) + { + throw new FormatException($"Unexpected JSON Token Type '{token?.Type}'. Expected a JSON Object."); + } + + var negotiationJObject = (JObject)token; + var protocol = JsonUtils.GetRequiredProperty(negotiationJObject, ProtocolPropertyName); + negotiationMessage = new NegotiationMessage(protocol); + } + } + return true; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index b865abdc5c..9e91c1b47e 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -59,13 +59,10 @@ namespace Microsoft.AspNetCore.SignalR public async Task OnConnectedAsync(ConnectionContext connection) { + await ProcessNegotiate(connection); + try { - // Resolve the Hub Protocol for the connection and store it in metadata - // Other components, outside the Hub, may need to know what protocol is in use - // for a particular connection, so we store it here. - connection.Metadata[HubConnectionMetadataNames.HubProtocol] = _protocolResolver.GetProtocol(connection); - await _lifetimeManager.OnConnectedAsync(connection); await RunHubAsync(connection); } @@ -75,6 +72,26 @@ namespace Microsoft.AspNetCore.SignalR } } + private async Task ProcessNegotiate(ConnectionContext connection) + { + while (await connection.Transport.In.WaitToReadAsync()) + { + while (connection.Transport.In.TryRead(out var buffer)) + { + if (NegotiationProtocol.TryReadProtocolNegotiationMessage(buffer, out var negotiationMessage)) + { + // Resolve the Hub Protocol for the connection and store it in metadata + // Other components, outside the Hub, may need to know what protocol is in use + // for a particular connection, so we store it here. + connection.Metadata[HubConnectionMetadataNames.HubProtocol] = + _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection); + + return; + } + } + } + } + private async Task RunHubAsync(ConnectionContext connection) { await HubOnConnectedAsync(connection); diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs index c4e8957522..2fdc6202ed 100644 --- a/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; using Newtonsoft.Json; @@ -9,10 +10,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal { public class DefaultHubProtocolResolver : IHubProtocolResolver { - public IHubProtocol GetProtocol(ConnectionContext connection) + public IHubProtocol GetProtocol(string protocolName, ConnectionContext connection) { - // TODO: Allow customization of this serializer! - return new JsonHubProtocol(new JsonSerializer()); + switch (protocolName?.ToLowerInvariant()) + { + case "json": + return new JsonHubProtocol(new JsonSerializer()); + case "messagepack": + return new MessagePackHubProtocol(); + default: + throw new NotSupportedException($"The protocol '{protocolName ?? "(null)"}' is not supported."); + } } } } diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs index 9d2c1f1bb8..e797528770 100644 --- a/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs @@ -8,6 +8,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal { public interface IHubProtocolResolver { - IHubProtocol GetProtocol(ConnectionContext connection); + IHubProtocol GetProtocol(string protocolName, ConnectionContext connection); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 85323a8829..d01519bed9 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -2,11 +2,13 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Reactive.Linq; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; @@ -14,6 +16,7 @@ using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Console; +using Newtonsoft.Json; using Xunit; namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests @@ -47,13 +50,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests _testServer = new TestServer(webHostBuilder); } - [Fact] - public async Task CheckFixedMessage() + [Theory] + [MemberData(nameof(HubProtocols))] + public async Task CheckFixedMessage(IHubProtocol protocol) { var loggerFactory = CreateLogger(); var httpConnection = new HttpConnection(new Uri("http://test/hubs"), TransportType.LongPolling, loggerFactory, _testServer.CreateHandler()); - var connection = new HubConnection(httpConnection, loggerFactory); + var connection = new HubConnection(httpConnection, protocol, loggerFactory); try { await connection.StartAsync(); @@ -68,14 +72,15 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } - [Fact] - public async Task CanSendAndReceiveMessage() + [Theory] + [MemberData(nameof(HubProtocols))] + public async Task CanSendAndReceiveMessage(IHubProtocol protocol) { var loggerFactory = CreateLogger(); const string originalMessage = "SignalR"; var httpConnection = new HttpConnection(new Uri("http://test/hubs"), TransportType.LongPolling, loggerFactory, _testServer.CreateHandler()); - var connection = new HubConnection(httpConnection, loggerFactory); + var connection = new HubConnection(httpConnection, protocol, loggerFactory); try { await connection.StartAsync(); @@ -90,14 +95,15 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } - [Fact] - public async Task MethodsAreCaseInsensitive() + [Theory] + [MemberData(nameof(HubProtocols))] + public async Task MethodsAreCaseInsensitive(IHubProtocol protocol) { var loggerFactory = CreateLogger(); const string originalMessage = "SignalR"; var httpConnection = new HttpConnection(new Uri("http://test/hubs"), TransportType.LongPolling, loggerFactory, _testServer.CreateHandler()); - var connection = new HubConnection(httpConnection, loggerFactory); + var connection = new HubConnection(httpConnection, protocol, loggerFactory); try { await connection.StartAsync(); @@ -112,14 +118,15 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } - [Fact] - public async Task CanInvokeClientMethodFromServer() + [Theory] + [MemberData(nameof(HubProtocols))] + public async Task CanInvokeClientMethodFromServer(IHubProtocol protocol) { var loggerFactory = CreateLogger(); const string originalMessage = "SignalR"; var httpConnection = new HttpConnection(new Uri("http://test/hubs"), TransportType.LongPolling, loggerFactory, _testServer.CreateHandler()); - var connection = new HubConnection(httpConnection, loggerFactory); + var connection = new HubConnection(httpConnection, protocol, loggerFactory); try { await connection.StartAsync(); @@ -137,13 +144,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } - [Fact] - public async Task CanStreamClientMethodFromServer() + [Theory] + [MemberData(nameof(HubProtocols))] + public async Task CanStreamClientMethodFromServer(IHubProtocol protocol) { var loggerFactory = CreateLogger(); var httpConnection = new HttpConnection(new Uri("http://test/hubs"), TransportType.LongPolling, loggerFactory, _testServer.CreateHandler()); - var connection = new HubConnection(httpConnection, loggerFactory); + var connection = new HubConnection(httpConnection, protocol, loggerFactory); try { await connection.StartAsync(); @@ -160,13 +168,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } - [Fact] - public async Task ServerClosesConnectionIfHubMethodCannotBeResolved() + [Theory] + [MemberData(nameof(HubProtocols))] + public async Task ServerClosesConnectionIfHubMethodCannotBeResolved(IHubProtocol hubProtocol) { var loggerFactory = CreateLogger(); var httpConnection = new HttpConnection(new Uri("http://test/hubs"), TransportType.LongPolling, loggerFactory, _testServer.CreateHandler()); - var connection = new HubConnection(httpConnection, loggerFactory); + var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory); try { await connection.StartAsync(); @@ -182,6 +191,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + public static IEnumerable HubProtocols() => + new[] + { + new object[] { new JsonHubProtocol(new JsonSerializer()) }, + new object[] { new MessagePackHubProtocol() }, + }; + public void Dispose() { _testServer.Dispose(); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs index ea0d1d0ecd..8f344a424c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs @@ -17,6 +17,25 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // don't cause problems. public class HubConnectionProtocolTests { + [Fact] + public async Task ClientSendsNegotationMessageWhenStartingConnection() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + var negotiationMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); + + Assert.Equal("19:{\"protocol\":\"json\"};", negotiationMessage); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + [Fact] public async Task InvokeSendsAnInvocationMessage() { @@ -28,6 +47,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var invokeTask = hubConnection.Invoke("Foo"); + // skip negotiation + await connection.ReadSentTextMessageAsync().OrTimeout(); var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); Assert.Equal("59:{\"invocationId\":\"1\",\"type\":1,\"target\":\"Foo\",\"arguments\":[]};", invokeMessage); @@ -50,6 +71,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var channel = hubConnection.Stream("Foo"); + // skip negotiation + await connection.ReadSentTextMessageAsync().OrTimeout(); var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); Assert.Equal("59:{\"invocationId\":\"1\",\"type\":1,\"target\":\"Foo\",\"arguments\":[]};", invokeMessage); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index cf15f6688c..db82e86872 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Buffers; using System.Collections.Generic; using System.IO; using System.Threading.Tasks; @@ -174,6 +173,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }; } + public string Name { get => "MockHubProtocol"; } + public bool TryParseMessages(ReadOnlySpan input, IInvocationBinder binder, out IList messages) { messages = new List(); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index 535f605fc3..2d59afac37 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -3,7 +3,6 @@ using System; using System.IO; -using System.Net.Http; using System.Text; using System.Threading; using System.Threading.Tasks; diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/NegotiationProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/NegotiationProtocolTests.cs new file mode 100644 index 0000000000..2900f4b4f2 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/NegotiationProtocolTests.cs @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Text; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol +{ + public class NegotiationProtocolTests + { + [Fact] + public void CanRoundtripNegotiation() + { + var negotiationMessage = new NegotiationMessage(protocol: "dummy"); + using (var ms = new MemoryStream()) + { + Assert.True(NegotiationProtocol.TryWriteProtocolNegotiationMessage(negotiationMessage, ms)); + Assert.True(NegotiationProtocol.TryReadProtocolNegotiationMessage(ms.ToArray(), out var deserializedMessage)); + + Assert.NotNull(deserializedMessage); + Assert.Equal(negotiationMessage.Protocol, deserializedMessage.Protocol); + } + } + + [Theory] + [InlineData("2:", "Unable to parse payload as a negotiation message.")] + [InlineData("2:42;", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")] + [InlineData("4:\"42\";", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")] + [InlineData("4:null;", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")] + [InlineData("2:{};", "Missing required property 'protocol'.")] + [InlineData("2:[];", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")] + public void ParsingNegotiationMessageThrowsForInvalidMessages(string payload, string expectedMessage) + { + var message = Encoding.UTF8.GetBytes(payload); + + var exception = Assert.Throws(() => + Assert.True(NegotiationProtocol.TryReadProtocolNegotiationMessage(message, out var deserializedMessage))); + + Assert.Equal(expectedMessage, exception.Message); + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs new file mode 100644 index 0000000000..e84a5a640d --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs @@ -0,0 +1,46 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.Sockets; +using Moq; +using Newtonsoft.Json; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests +{ + public class DefaultHubProtocolResolverTests + { + [Theory] + [MemberData(nameof(HubProtocols))] + public void DefaultHubProtocolResolverTestsCanCreateSupportedProtocols(IHubProtocol protocol) + { + var mockConnection = new Mock(); + Assert.IsType( + protocol.GetType(), + new DefaultHubProtocolResolver().GetProtocol(protocol.Name, mockConnection.Object)); + } + + [Theory] + [InlineData(null)] + [InlineData("dummy")] + public void DefaultHubProtocolResolverThrowsForNotSupportedProtocol(string protocolName) + { + var mockConnection = new Mock(); + var exception = Assert.Throws( + () => new DefaultHubProtocolResolver().GetProtocol(protocolName, mockConnection.Object)); + + Assert.Equal($"The protocol '{protocolName ?? "(null)"}' is not supported.", exception.Message); + } + + public static IEnumerable HubProtocols() => + new[] + { + new object[] { new JsonHubProtocol(new JsonSerializer()) }, + new object[] { new MessagePackHubProtocol() }, + }; + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index 847b8f3a0e..b6fb1bcc2b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.IO; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; @@ -41,6 +42,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests _protocol = new JsonHubProtocol(new JsonSerializer()); _cts = new CancellationTokenSource(); + + using (var memoryStream = new MemoryStream()) + { + NegotiationProtocol.TryWriteProtocolNegotiationMessage(new NegotiationMessage(_protocol.Name), memoryStream); + Application.Out.TryWrite(memoryStream.ToArray()); + } } public async Task> StreamAsync(string methodName, params object[] args)