From 7f86b92f7e212ab37ff56524dff660810d3be184 Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Thu, 29 Mar 2018 17:50:45 +1300 Subject: [PATCH] Handshake and negotiation optimization (#1731) --- .../HubConnectionContextBenchmark.cs | 8 +- src/Common/ForceAsyncAwaiter.cs | 6 +- .../Internal/Protocol => Common}/JsonUtils.cs | 65 +++++- .../HubConnection.cs | 2 +- .../Internal/Protocol/HandshakeProtocol.cs | 109 +++++++--- .../Protocol/HandshakeResponseMessage.cs | 1 + .../Internal/Protocol/JsonArrayPool.cs | 31 --- .../Internal/Protocol/JsonHubProtocol.cs | 8 +- ...Microsoft.AspNetCore.SignalR.Common.csproj | 4 + .../Microsoft.AspNetCore.SignalR.Redis.csproj | 4 + .../RedisHubLifetimeManager.cs | 5 +- .../HttpConnection.cs | 37 +--- .../Internal/AvailableTransport.cs | 13 ++ .../Internal/NegotiateProtocol.cs | 196 ++++++++++++++++++ .../Internal/NegotiationResponse.cs | 13 ++ ...soft.AspNetCore.Sockets.Common.Http.csproj | 6 + .../HttpConnectionDispatcher.cs | 95 ++++----- .../HttpConnectionTests.Negotiate.cs | 5 +- .../Protocol/HandshakeProtocolTests.cs | 10 +- .../Internal/Protocol/JsonHubProtocolTests.cs | 4 +- .../NegotiateProtocolTests.cs | 51 +++++ 21 files changed, 498 insertions(+), 175 deletions(-) rename src/{Microsoft.AspNetCore.SignalR.Common/Internal/Protocol => Common}/JsonUtils.cs (63%) delete mode 100644 src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonArrayPool.cs create mode 100644 src/Microsoft.AspNetCore.Sockets.Common.Http/Internal/AvailableTransport.cs create mode 100644 src/Microsoft.AspNetCore.Sockets.Common.Http/Internal/NegotiateProtocol.cs create mode 100644 src/Microsoft.AspNetCore.Sockets.Common.Http/Internal/NegotiationResponse.cs create mode 100644 test/Microsoft.AspNetCore.Sockets.Tests/NegotiateProtocolTests.cs diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs index 0108e3557c..55c5e77c6a 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs @@ -31,14 +31,14 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks private TestUserIdProvider _userIdProvider; private List _supportedProtocols; private TestDuplexPipe _pipe; - private ReadResult _handshakeResponseResult; + private ReadResult _handshakeRequestResult; [GlobalSetup] public void GlobalSetup() { var memoryBufferWriter = new MemoryBufferWriter(); HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter); - _handshakeResponseResult = new ReadResult(new ReadOnlySequence(memoryBufferWriter.ToArray()), false, false); + _handshakeRequestResult = new ReadResult(new ReadOnlySequence(memoryBufferWriter.ToArray()), false, false); _pipe = new TestDuplexPipe(); @@ -54,7 +54,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [Benchmark] public async Task SuccessHandshakeAsync() { - _pipe.AddReadResult(_handshakeResponseResult); + _pipe.AddReadResult(_handshakeRequestResult); await _hubConnectionContext.HandshakeAsync(TimeSpan.FromSeconds(5), _supportedProtocols, _successHubProtocolResolver, _userIdProvider); } @@ -62,7 +62,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [Benchmark] public async Task ErrorHandshakeAsync() { - _pipe.AddReadResult(_handshakeResponseResult); + _pipe.AddReadResult(_handshakeRequestResult); await _hubConnectionContext.HandshakeAsync(TimeSpan.FromSeconds(5), _supportedProtocols, _failureHubProtocolResolver, _userIdProvider); } diff --git a/src/Common/ForceAsyncAwaiter.cs b/src/Common/ForceAsyncAwaiter.cs index b52aa8a9a6..4515637e59 100644 --- a/src/Common/ForceAsyncAwaiter.cs +++ b/src/Common/ForceAsyncAwaiter.cs @@ -7,7 +7,7 @@ using System.Threading.Tasks; namespace Microsoft.AspNetCore.Sockets.Internal { - public static class ForceAsyncTaskExtensions + internal static class ForceAsyncTaskExtensions { /// /// Returns an awaitable/awaiter that will ensure the continuation is executed @@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal } } - public struct ForceAsyncAwaiter : ICriticalNotifyCompletion + internal struct ForceAsyncAwaiter : ICriticalNotifyCompletion { private readonly Task _task; @@ -50,7 +50,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal } } - public struct ForceAsyncAwaiter : ICriticalNotifyCompletion + internal struct ForceAsyncAwaiter : ICriticalNotifyCompletion { private readonly Task _task; diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonUtils.cs b/src/Common/JsonUtils.cs similarity index 63% rename from src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonUtils.cs rename to src/Common/JsonUtils.cs index 37ce8d8d85..98fb2199b4 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonUtils.cs +++ b/src/Common/JsonUtils.cs @@ -2,25 +2,40 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; using System.IO; using Newtonsoft.Json; using Newtonsoft.Json.Linq; -namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +namespace Microsoft.AspNetCore.SignalR.Internal { - public static class JsonUtils + internal static class JsonUtils { - internal static JsonTextReader CreateJsonTextReader(Utf8BufferTextReader textReader) + internal static JsonTextReader CreateJsonTextReader(TextReader textReader) { var reader = new JsonTextReader(textReader); reader.ArrayPool = JsonArrayPool.Shared; - // Don't close the output, Utf8BufferTextReader is resettable + // Don't close the input, leave closing to the caller reader.CloseInput = false; return reader; } + internal static JsonTextWriter CreateJsonTextWriter(TextWriter textWriter) + { + var writer = new JsonTextWriter(textWriter); + + // Don't close the output, leave closing to the caller + writer.CloseOutput = false; + + // SignalR will always write a complete JSON response + // This setting will prevent an error during writing be hidden by another error writing on dispose + writer.AutoCompleteOnClose = false; + + return writer; + } + public static JObject GetObject(JToken token) { if (token == null || token.Type != JTokenType.Object) @@ -82,6 +97,22 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol return tokenType.ToString(); } + public static void EnsureObjectStart(JsonTextReader reader) + { + if (reader.TokenType != JsonToken.StartObject) + { + throw new InvalidDataException($"Unexpected JSON Token Type '{GetTokenString(reader.TokenType)}'. Expected a JSON Object."); + } + } + + public static void EnsureArrayStart(JsonTextReader reader) + { + if (reader.TokenType != JsonToken.StartArray) + { + throw new InvalidDataException($"Unexpected JSON Token Type '{GetTokenString(reader.TokenType)}'. Expected a JSON Array."); + } + } + public static int? ReadAsInt32(JsonTextReader reader, string propertyName) { reader.Read(); @@ -115,10 +146,32 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { if (!reader.Read()) { - throw new JsonReaderException("Unexpected end when reading JSON"); + throw new InvalidDataException("Unexpected end when reading JSON."); } return true; } + + private class JsonArrayPool : IArrayPool + { + private readonly ArrayPool _inner; + + internal static readonly JsonArrayPool Shared = new JsonArrayPool(ArrayPool.Shared); + + public JsonArrayPool(ArrayPool inner) + { + _inner = inner; + } + + public T[] Rent(int minimumLength) + { + return _inner.Rent(minimumLength); + } + + public void Return(T[] array) + { + _inner.Return(array); + } + } } -} +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index def78d57d4..6d21f1729b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -523,7 +523,7 @@ namespace Microsoft.AspNetCore.SignalR.Client { if (HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message)) { - if (!string.IsNullOrEmpty(message.Error)) + if (message.Error != null) { Log.HandshakeServerError(_logger, message.Error); throw new HubException( diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeProtocol.cs index cdd61a3aec..de181f4603 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeProtocol.cs @@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol public static class HandshakeProtocol { private const string ProtocolPropertyName = "protocol"; - private const string ProtocolVersionName = "version"; + private const string ProtocolVersionPropertyName = "version"; private const string ErrorPropertyName = "error"; private const string TypePropertyName = "type"; @@ -23,12 +23,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol var textWriter = Utf8BufferTextWriter.Get(output); try { - using (var writer = CreateJsonTextWriter(textWriter)) + using (var writer = JsonUtils.CreateJsonTextWriter(textWriter)) { writer.WriteStartObject(); writer.WritePropertyName(ProtocolPropertyName); writer.WriteValue(requestMessage.Protocol); - writer.WritePropertyName(ProtocolVersionName); + writer.WritePropertyName(ProtocolVersionPropertyName); writer.WriteValue(requestMessage.Version); writer.WriteEndObject(); writer.Flush(); @@ -47,7 +47,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol var textWriter = Utf8BufferTextWriter.Get(output); try { - using (var writer = CreateJsonTextWriter(textWriter)) + using (var writer = JsonUtils.CreateJsonTextWriter(textWriter)) { writer.WriteStartObject(); if (!string.IsNullOrEmpty(responseMessage.Error)) @@ -68,14 +68,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol TextMessageFormatter.WriteRecordSeparator(output); } - private static JsonTextWriter CreateJsonTextWriter(TextWriter textWriter) - { - var writer = new JsonTextWriter(textWriter); - writer.CloseOutput = false; - - return writer; - } - public static bool TryParseResponseMessage(ref ReadOnlySequence buffer, out HandshakeResponseMessage responseMessage) { if (!TextMessageParser.TryParseMessage(ref buffer, out var payload)) @@ -90,19 +82,42 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { using (var reader = JsonUtils.CreateJsonTextReader(textReader)) { - var token = JToken.ReadFrom(reader); - var handshakeJObject = JsonUtils.GetObject(token); + JsonUtils.CheckRead(reader); + JsonUtils.EnsureObjectStart(reader); - // a handshake response does not have a type - // check the incoming message was not any other type of message - var type = JsonUtils.GetOptionalProperty(handshakeJObject, TypePropertyName); - if (!string.IsNullOrEmpty(type)) + string error = null; + + var completed = false; + while (!completed && JsonUtils.CheckRead(reader)) { - throw new InvalidOperationException("Handshake response should not have a 'type' value."); - } + switch (reader.TokenType) + { + case JsonToken.PropertyName: + string memberName = reader.Value.ToString(); - var error = JsonUtils.GetOptionalProperty(handshakeJObject, ErrorPropertyName); - responseMessage = new HandshakeResponseMessage(error); + switch (memberName) + { + case TypePropertyName: + // a handshake response does not have a type + // check the incoming message was not any other type of message + throw new InvalidDataException("Handshake response should not have a 'type' value."); + case ErrorPropertyName: + error = JsonUtils.ReadAsString(reader, ErrorPropertyName); + break; + default: + reader.Skip(); + break; + } + break; + case JsonToken.EndObject: + completed = true; + break; + default: + throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading handshake response JSON."); + } + }; + + responseMessage = (error != null) ? new HandshakeResponseMessage(error) : HandshakeResponseMessage.Empty; return true; } } @@ -125,11 +140,51 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { using (var reader = JsonUtils.CreateJsonTextReader(textReader)) { - var token = JToken.ReadFrom(reader); - var handshakeJObject = JsonUtils.GetObject(token); - var protocol = JsonUtils.GetRequiredProperty(handshakeJObject, ProtocolPropertyName); - var protocolVersion = JsonUtils.GetRequiredProperty(handshakeJObject, ProtocolVersionName, JTokenType.Integer); - requestMessage = new HandshakeRequestMessage(protocol, protocolVersion); + JsonUtils.CheckRead(reader); + JsonUtils.EnsureObjectStart(reader); + + string protocol = null; + int? protocolVersion = null; + + var completed = false; + while (!completed && JsonUtils.CheckRead(reader)) + { + switch (reader.TokenType) + { + case JsonToken.PropertyName: + string memberName = reader.Value.ToString(); + + switch (memberName) + { + case ProtocolPropertyName: + protocol = JsonUtils.ReadAsString(reader, ProtocolPropertyName); + break; + case ProtocolVersionPropertyName: + protocolVersion = JsonUtils.ReadAsInt32(reader, ProtocolVersionPropertyName); + break; + default: + reader.Skip(); + break; + } + break; + case JsonToken.EndObject: + completed = true; + break; + default: + throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading handshake request JSON."); + } + } + + if (protocol == null) + { + throw new InvalidDataException($"Missing required property '{ProtocolPropertyName}'."); + } + if (protocolVersion == null) + { + throw new InvalidDataException($"Missing required property '{ProtocolVersionPropertyName}'."); + } + + requestMessage = new HandshakeRequestMessage(protocol, protocolVersion.Value); } } finally diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeResponseMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeResponseMessage.cs index 6a02f0bb37..08f7c8ed05 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeResponseMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeResponseMessage.cs @@ -11,6 +11,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol public HandshakeResponseMessage(string error) { + // Note that a response with an empty string for error in the JSON is considered an errored response Error = error; } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonArrayPool.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonArrayPool.cs deleted file mode 100644 index 5b4ae0e69e..0000000000 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonArrayPool.cs +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Buffers; -using Newtonsoft.Json; - -namespace Microsoft.AspNetCore.SignalR.Internal.Protocol -{ - internal class JsonArrayPool : IArrayPool - { - private readonly ArrayPool _inner; - - internal static readonly JsonArrayPool Shared = new JsonArrayPool(ArrayPool.Shared); - - public JsonArrayPool(ArrayPool inner) - { - _inner = inner; - } - - public T[] Rent(int minimumLength) - { - return _inner.Rent(minimumLength); - } - - public void Return(T[] array) - { - _inner.Return(array); - } - } -} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs index 0e0dfb15e3..f4e5df07b8 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs @@ -113,10 +113,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol JsonUtils.CheckRead(reader); // We're always parsing a JSON object - if (reader.TokenType != JsonToken.StartObject) - { - throw new InvalidDataException($"Unexpected JSON Token Type '{JsonUtils.GetTokenString(reader.TokenType)}'. Expected a JSON Object."); - } + JsonUtils.EnsureObjectStart(reader); do { @@ -344,7 +341,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private void WriteMessageCore(HubMessage message, Stream stream) { - using (var writer = new JsonTextWriter(new StreamWriter(stream, _utf8NoBom, 1024, leaveOpen: true))) + using (var writer = JsonUtils.CreateJsonTextWriter(new StreamWriter(stream, _utf8NoBom, 1024, leaveOpen: true))) { writer.WriteStartObject(); switch (message) @@ -385,6 +382,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}"); } writer.WriteEndObject(); + writer.Flush(); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj b/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj index 4bcebd9980..440fc39440 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj +++ b/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj @@ -7,6 +7,10 @@ true + + + + diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj b/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj index 460371c6d7..e8b3c9f741 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj @@ -5,6 +5,10 @@ netstandard2.0 + + + + diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index f3e5434b7c..d8c8cce8ba 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -9,6 +9,7 @@ using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Redis.Internal; using Microsoft.Extensions.Logging; @@ -214,10 +215,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis { byte[] payload; using (var stream = new LimitArrayPoolWriteStream()) - using (var writer = new JsonTextWriter(new StreamWriter(stream))) + using (var writer = JsonUtils.CreateJsonTextWriter(new StreamWriter(stream))) { _serializer.Serialize(writer, message); - await writer.FlushAsync(); + writer.Flush(); payload = stream.ToArray(); } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index 4a96c840da..1b94e02235 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -299,7 +299,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Http using (var response = await httpClient.SendAsync(request)) { response.EnsureSuccessStatusCode(); - var negotiateResponse = await ParseNegotiateResponse(response); + var negotiateResponse = NegotiateProtocol.ParseResponse(await response.Content.ReadAsStreamAsync()); Log.ConnectionEstablished(_logger, negotiateResponse.ConnectionId); return negotiateResponse; } @@ -312,29 +312,6 @@ namespace Microsoft.AspNetCore.Sockets.Client.Http } } - private static async Task ParseNegotiateResponse(HttpResponseMessage response) - { - NegotiationResponse negotiationResponse; - using (var reader = new JsonTextReader(new StreamReader(await response.Content.ReadAsStreamAsync()))) - { - try - { - negotiationResponse = new JsonSerializer().Deserialize(reader); - } - catch (Exception ex) - { - throw new FormatException("Invalid negotiation response received.", ex); - } - } - - if (negotiationResponse == null) - { - throw new FormatException("Invalid negotiation response received."); - } - - return negotiationResponse; - } - private static Uri CreateConnectUrl(Uri url, string connectionId) { if (string.IsNullOrWhiteSpace(connectionId)) @@ -454,17 +431,5 @@ namespace Microsoft.AspNetCore.Sockets.Client.Http _logScope.ConnectionId = _connectionId; return negotiationResponse; } - - private class NegotiationResponse - { - public string ConnectionId { get; set; } - public AvailableTransport[] AvailableTransports { get; set; } - } - - private class AvailableTransport - { - public string Transport { get; set; } - public string[] TransferFormats { get; set; } - } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Common.Http/Internal/AvailableTransport.cs b/src/Microsoft.AspNetCore.Sockets.Common.Http/Internal/AvailableTransport.cs new file mode 100644 index 0000000000..055688d1e8 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common.Http/Internal/AvailableTransport.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Sockets.Internal +{ + public class AvailableTransport + { + public string Transport { get; set; } + public List TransferFormats { get; set; } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Sockets.Common.Http/Internal/NegotiateProtocol.cs b/src/Microsoft.AspNetCore.Sockets.Common.Http/Internal/NegotiateProtocol.cs new file mode 100644 index 0000000000..87f6a337ab --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common.Http/Internal/NegotiateProtocol.cs @@ -0,0 +1,196 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using Microsoft.AspNetCore.SignalR.Internal; +using Newtonsoft.Json; + +namespace Microsoft.AspNetCore.Sockets.Internal +{ + public static class NegotiateProtocol + { + private static readonly UTF8Encoding _utf8NoBom = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false); + + private const string ConnectionIdPropertyName = "connectionId"; + private const string AvailableTransportsPropertyName = "availableTransports"; + private const string TransportPropertyName = "transport"; + private const string TransferFormatsPropertyName = "transferFormats"; + + public static void WriteResponse(NegotiationResponse response, Stream output) + { + using (var jsonWriter = JsonUtils.CreateJsonTextWriter(new StreamWriter(output, _utf8NoBom, 1024, leaveOpen: true))) + { + jsonWriter.WriteStartObject(); + jsonWriter.WritePropertyName(ConnectionIdPropertyName); + jsonWriter.WriteValue(response.ConnectionId); + jsonWriter.WritePropertyName(AvailableTransportsPropertyName); + jsonWriter.WriteStartArray(); + + foreach (var availableTransport in response.AvailableTransports) + { + jsonWriter.WriteStartObject(); + jsonWriter.WritePropertyName(TransportPropertyName); + jsonWriter.WriteValue(availableTransport.Transport); + jsonWriter.WritePropertyName(TransferFormatsPropertyName); + jsonWriter.WriteStartArray(); + + foreach (var transferFormat in availableTransport.TransferFormats) + { + jsonWriter.WriteValue(transferFormat); + } + + jsonWriter.WriteEndArray(); + jsonWriter.WriteEndObject(); + } + + jsonWriter.WriteEndArray(); + jsonWriter.WriteEndObject(); + + jsonWriter.Flush(); + } + } + + public static NegotiationResponse ParseResponse(Stream content) + { + try + { + using (var reader = JsonUtils.CreateJsonTextReader(new StreamReader(content))) + { + JsonUtils.CheckRead(reader); + JsonUtils.EnsureObjectStart(reader); + + string connectionId = null; + List availableTransports = null; + + var completed = false; + while (!completed && JsonUtils.CheckRead(reader)) + { + switch (reader.TokenType) + { + case JsonToken.PropertyName: + var memberName = reader.Value.ToString(); + + switch (memberName) + { + case ConnectionIdPropertyName: + connectionId = JsonUtils.ReadAsString(reader, ConnectionIdPropertyName); + break; + case AvailableTransportsPropertyName: + JsonUtils.CheckRead(reader); + JsonUtils.EnsureArrayStart(reader); + + availableTransports = new List(); + while (JsonUtils.CheckRead(reader)) + { + if (reader.TokenType == JsonToken.StartObject) + { + availableTransports.Add(ParseAvailableTransport(reader)); + } + else if (reader.TokenType == JsonToken.EndArray) + { + break; + } + } + break; + default: + reader.Skip(); + break; + } + break; + case JsonToken.EndObject: + completed = true; + break; + default: + throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading negotiation response JSON."); + } + } + + if (connectionId == null) + { + throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'."); + } + + if (availableTransports == null) + { + throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'."); + } + + return new NegotiationResponse + { + ConnectionId = connectionId, + AvailableTransports = availableTransports + }; + } + } + catch (Exception ex) + { + throw new InvalidDataException("Invalid negotiation response received.", ex); + } + } + + private static AvailableTransport ParseAvailableTransport(JsonTextReader reader) + { + var availableTransport = new AvailableTransport(); + + while (JsonUtils.CheckRead(reader)) + { + switch (reader.TokenType) + { + case JsonToken.PropertyName: + string memberName = reader.Value.ToString(); + + switch (memberName) + { + case TransportPropertyName: + availableTransport.Transport = JsonUtils.ReadAsString(reader, TransportPropertyName); + break; + case TransferFormatsPropertyName: + JsonUtils.CheckRead(reader); + JsonUtils.EnsureArrayStart(reader); + + bool completed = false; + availableTransport.TransferFormats = new List(); + while (!completed && JsonUtils.CheckRead(reader)) + { + switch (reader.TokenType) + { + case JsonToken.String: + availableTransport.TransferFormats.Add(reader.Value.ToString()); + break; + case JsonToken.EndArray: + completed = true; + break; + default: + throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading transfer formats JSON."); + } + } + break; + default: + reader.Skip(); + break; + } + break; + case JsonToken.EndObject: + if (availableTransport.Transport == null) + { + throw new InvalidDataException($"Missing required property '{TransportPropertyName}'."); + } + + if (availableTransport.TransferFormats == null) + { + throw new InvalidDataException($"Missing required property '{TransferFormatsPropertyName}'."); + } + + return availableTransport; + default: + throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading available transport JSON."); + } + } + + throw new InvalidDataException("Unexpected end when reading JSON."); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Sockets.Common.Http/Internal/NegotiationResponse.cs b/src/Microsoft.AspNetCore.Sockets.Common.Http/Internal/NegotiationResponse.cs new file mode 100644 index 0000000000..11552c2e70 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common.Http/Internal/NegotiationResponse.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Sockets.Internal +{ + public class NegotiationResponse + { + public string ConnectionId { get; set; } + public List AvailableTransports { get; set; } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Sockets.Common.Http/Microsoft.AspNetCore.Sockets.Common.Http.csproj b/src/Microsoft.AspNetCore.Sockets.Common.Http/Microsoft.AspNetCore.Sockets.Common.Http.csproj index f63bdfe987..19b4144274 100644 --- a/src/Microsoft.AspNetCore.Sockets.Common.Http/Microsoft.AspNetCore.Sockets.Common.Http.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Common.Http/Microsoft.AspNetCore.Sockets.Common.Http.csproj @@ -5,7 +5,13 @@ netstandard2.0 Microsoft.AspNetCore.Sockets + + + + + + diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index 4c3d1f7ae1..e22a4a0367 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -24,6 +24,27 @@ namespace Microsoft.AspNetCore.Sockets { public partial class HttpConnectionDispatcher { + private static readonly AvailableTransport _webSocketAvailableTransport = + new AvailableTransport + { + Transport = nameof(TransportType.WebSockets), + TransferFormats = new List { nameof(TransferFormat.Text), nameof(TransferFormat.Binary) } + }; + + private static readonly AvailableTransport _serverSentEventsAvailableTransport = + new AvailableTransport + { + Transport = nameof(TransportType.ServerSentEvents), + TransferFormats = new List { nameof(TransferFormat.Text) } + }; + + private static readonly AvailableTransport _longPollingAvailableTransport = + new AvailableTransport + { + Transport = nameof(TransportType.LongPolling), + TransferFormats = new List { nameof(TransferFormat.Text), nameof(TransferFormat.Binary) } + }; + private readonly HttpConnectionManager _manager; private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; @@ -368,7 +389,7 @@ namespace Microsoft.AspNetCore.Sockets logScope.ConnectionId = connection.ConnectionId; // Get the bytes for the connection id - var negotiateResponseBuffer = Encoding.UTF8.GetBytes(GetNegotiatePayload(connection.ConnectionId, context, options)); + var negotiateResponseBuffer = GetNegotiatePayload(connection.ConnectionId, context, options); Log.NegotiationRequest(_logger); @@ -377,40 +398,31 @@ namespace Microsoft.AspNetCore.Sockets return context.Response.Body.WriteAsync(negotiateResponseBuffer, 0, negotiateResponseBuffer.Length); } - private static string GetNegotiatePayload(string connectionId, HttpContext context, HttpConnectionOptions options) + private static byte[] GetNegotiatePayload(string connectionId, HttpContext context, HttpConnectionOptions options) { - var sb = new StringBuilder(); - using (var jsonWriter = new JsonTextWriter(new StringWriter(sb))) + NegotiationResponse response = new NegotiationResponse(); + response.ConnectionId = connectionId; + response.AvailableTransports = new List(); + + if ((options.Transports & TransportType.WebSockets) != 0 && ServerHasWebSockets(context.Features)) { - jsonWriter.WriteStartObject(); - jsonWriter.WritePropertyName("connectionId"); - jsonWriter.WriteValue(connectionId); - jsonWriter.WritePropertyName("availableTransports"); - jsonWriter.WriteStartArray(); - - if (ServerHasWebSockets(context.Features)) - { - if ((options.Transports & TransportType.WebSockets) != 0) - { - WriteTransport(jsonWriter, nameof(TransportType.WebSockets), TransferFormat.Text | TransferFormat.Binary); - } - } - - if ((options.Transports & TransportType.ServerSentEvents) != 0) - { - WriteTransport(jsonWriter, nameof(TransportType.ServerSentEvents), TransferFormat.Text); - } - - if ((options.Transports & TransportType.LongPolling) != 0) - { - WriteTransport(jsonWriter, nameof(TransportType.LongPolling), TransferFormat.Text | TransferFormat.Binary); - } - - jsonWriter.WriteEndArray(); - jsonWriter.WriteEndObject(); + response.AvailableTransports.Add(_webSocketAvailableTransport); } - return sb.ToString(); + if ((options.Transports & TransportType.ServerSentEvents) != 0) + { + response.AvailableTransports.Add(_serverSentEventsAvailableTransport); + } + + if ((options.Transports & TransportType.LongPolling) != 0) + { + response.AvailableTransports.Add(_longPollingAvailableTransport); + } + + MemoryStream ms = new MemoryStream(); + NegotiateProtocol.WriteResponse(response, ms); + + return ms.ToArray(); } private static bool ServerHasWebSockets(IFeatureCollection features) @@ -418,27 +430,6 @@ namespace Microsoft.AspNetCore.Sockets return features.Get() != null; } - private static void WriteTransport(JsonWriter writer, string transportName, TransferFormat supportedTransferFormats) - { - writer.WriteStartObject(); - writer.WritePropertyName("transport"); - writer.WriteValue(transportName); - writer.WritePropertyName("transferFormats"); - writer.WriteStartArray(); - if ((supportedTransferFormats & TransferFormat.Binary) != 0) - { - writer.WriteValue(nameof(TransferFormat.Binary)); - } - - if ((supportedTransferFormats & TransferFormat.Text) != 0) - { - writer.WriteValue(nameof(TransferFormat.Text)); - } - - writer.WriteEndArray(); - writer.WriteEndObject(); - } - private static string GetConnectionId(HttpContext context) => context.Request.Query["id"]; private async Task ProcessSend(HttpContext context, HttpConnectionOptions options) diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs index a4da4dc7f0..1f0cf5a36e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO; using System.Net; using System.Threading.Tasks; using Microsoft.AspNetCore.Client.Tests; @@ -25,13 +26,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests [InlineData("Not Json")] public Task StartThrowsFormatExceptionIfNegotiationResponseIsInvalid(string negotiatePayload) { - return RunInvalidNegotiateResponseTest(negotiatePayload, "Invalid negotiation response received."); + return RunInvalidNegotiateResponseTest(negotiatePayload, "Invalid negotiation response received."); } [Fact] public Task StartThrowsFormatExceptionIfNegotiationResponseHasNoConnectionId() { - return RunInvalidNegotiateResponseTest(ResponseUtils.CreateNegotiationContent(connectionId: null), "Invalid connection id."); + return RunInvalidNegotiateResponseTest(ResponseUtils.CreateNegotiationContent(connectionId: string.Empty), "Invalid connection id."); } [Fact] diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/HandshakeProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/HandshakeProtocolTests.cs index dcb35d5839..adbc830a11 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/HandshakeProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/HandshakeProtocolTests.cs @@ -15,7 +15,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [Theory] [InlineData("{\"protocol\":\"dummy\",\"version\":1}\u001e", "dummy", 1)] [InlineData("{\"protocol\":\"\",\"version\":10}\u001e", "", 10)] - [InlineData("{\"protocol\":null,\"version\":123}\u001e", null, 123)] + [InlineData("{\"protocol\":\"\",\"version\":10,\"unknown\":null}\u001e", "", 10)] public void ParsingHandshakeRequestMessageSuccessForValidMessages(string json, string protocol, int version) { var message = new ReadOnlySequence(Encoding.UTF8.GetBytes(json)); @@ -29,8 +29,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [Theory] [InlineData("{\"error\":\"dummy\"}\u001e", "dummy")] [InlineData("{\"error\":\"\"}\u001e", "")] - [InlineData("{\"error\":null}\u001e", null)] [InlineData("{}\u001e", null)] + [InlineData("{\"unknown\":null}\u001e", null)] public void ParsingHandshakeResponseMessageSuccessForValidMessages(string json, string error) { var message = new ReadOnlySequence(Encoding.UTF8.GetBytes(json)); @@ -55,7 +55,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [InlineData("[]\u001e", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")] [InlineData("{\"protocol\":\"json\"}\u001e", "Missing required property 'version'.")] [InlineData("{\"version\":1}\u001e", "Missing required property 'protocol'.")] - [InlineData("{\"protocol\":null,\"version\":\"123\"}\u001e", "Expected 'version' to be of type Integer.")] + [InlineData("{\"version\":\"123\"}\u001e", "Expected 'version' to be of type Integer.")] + [InlineData("{\"protocol\":null,\"version\":123}\u001e", "Expected 'protocol' to be of type String.")] public void ParsingHandshakeRequestMessageThrowsForInvalidMessages(string payload, string expectedMessage) { var message = new ReadOnlySequence(Encoding.UTF8.GetBytes(payload)); @@ -71,12 +72,13 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [InlineData("\"42\"\u001e", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")] [InlineData("null\u001e", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")] [InlineData("[]\u001e", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")] + [InlineData("{\"error\":null}\u001e", "Expected 'error' to be of type String.")] public void ParsingHandshakeResponseMessageThrowsForInvalidMessages(string payload, string expectedMessage) { var message = new ReadOnlySequence(Encoding.UTF8.GetBytes(payload)); var exception = Assert.Throws(() => - HandshakeProtocol.TryParseRequestMessage(ref message, out _)); + HandshakeProtocol.TryParseResponseMessage(ref message, out _)); Assert.Equal(expectedMessage, exception.Message); } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs index 71ceb25eec..8fd29e4c3b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs @@ -141,7 +141,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol } [Theory] - [InlineData("", "Error reading JSON.")] + [InlineData("", "Unexpected end when reading JSON.")] [InlineData("null", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")] [InlineData("42", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")] [InlineData("'foo'", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")] @@ -177,7 +177,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [InlineData("{'type':'foo'}", "Expected 'type' to be of type Integer.")] [InlineData("{'type':3,'invocationId':'42','error':'foo','result':true}", "The 'error' and 'result' properties are mutually exclusive.")] - [InlineData("{'type':3,'invocationId':'42','result':true", "Error reading JSON.")] + [InlineData("{'type':3,'invocationId':'42','result':true", "Unexpected end when reading JSON.")] public void InvalidMessages(string input, string expectedMessage) { input = Frame(input); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/NegotiateProtocolTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/NegotiateProtocolTests.cs new file mode 100644 index 0000000000..cb6b92d804 --- /dev/null +++ b/test/Microsoft.AspNetCore.Sockets.Tests/NegotiateProtocolTests.cs @@ -0,0 +1,51 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.Sockets.Internal; +using Xunit; + +namespace Microsoft.AspNetCore.Sockets.Tests +{ + public class NegotiateProtocolTests + { + [Theory] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0])] + [InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0])] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new [] { "test"})] + public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports) + { + var responseData = Encoding.UTF8.GetBytes(json); + var ms = new MemoryStream(responseData); + var response = NegotiateProtocol.ParseResponse(ms); + + Assert.Equal(connectionId, response.ConnectionId); + Assert.Equal(availableTransports.Length, response.AvailableTransports.Count); + + var responseTransports = response.AvailableTransports.Select(t => t.Transport).ToList(); + + Assert.Equal(availableTransports, responseTransports); + } + + [Theory] + [InlineData("null", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")] + [InlineData("[]", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")] + [InlineData("{\"availableTransports\":[]}", "Missing required property 'connectionId'.")] + [InlineData("{\"connectionId\":123,\"availableTransports\":[]}", "Expected 'connectionId' to be of type String.")] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":null}", "Unexpected JSON Token Type 'Null'. Expected a JSON Array.")] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transferFormats\":[]}]}", "Missing required property 'transport'.")] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\"}]}", "Missing required property 'transferFormats'.")] + public void ParsingNegotiateResponseMessageThrowsForInvalid(string payload, string expectedMessage) + { + var responseData = Encoding.UTF8.GetBytes(payload); + var ms = new MemoryStream(responseData); + + var exception = Assert.Throws(() => NegotiateProtocol.ParseResponse(ms)); + + Assert.Equal(expectedMessage, exception.InnerException.Message); + } + } +}