diff --git a/eng/Dependencies.props b/eng/Dependencies.props index 91c0e3c509..c842d044f0 100644 --- a/eng/Dependencies.props +++ b/eng/Dependencies.props @@ -39,6 +39,7 @@ and are generated based on the last package release. + diff --git a/eng/Versions.props b/eng/Versions.props index fe0ceaa390..a689f66784 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -13,6 +13,7 @@ 3.0.0-preview-27324-5 3.0.0-preview-27324-5 + 4.6.0-preview.19073.11 4.6.0-preview.19073.11 4.6.0-preview.19073.11 4.6.0-preview.19073.11 diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index 84b9655d10..79312f4bd5 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -439,11 +439,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Client using (var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead)) { response.EnsureSuccessStatusCode(); - NegotiationResponse negotiateResponse; - using (var responseStream = await response.Content.ReadAsStreamAsync()) - { - negotiateResponse = NegotiateProtocol.ParseResponse(responseStream); - } + var responseBuffer = await response.Content.ReadAsByteArrayAsync(); + var negotiateResponse = NegotiateProtocol.ParseResponse(responseBuffer); if (!string.IsNullOrEmpty(negotiateResponse.Error)) { throw new Exception(negotiateResponse.Error); diff --git a/src/SignalR/common/Http.Connections.Common/src/Microsoft.AspNetCore.Http.Connections.Common.csproj b/src/SignalR/common/Http.Connections.Common/src/Microsoft.AspNetCore.Http.Connections.Common.csproj index 6b8195785f..a54ac4c17b 100644 --- a/src/SignalR/common/Http.Connections.Common/src/Microsoft.AspNetCore.Http.Connections.Common.csproj +++ b/src/SignalR/common/Http.Connections.Common/src/Microsoft.AspNetCore.Http.Connections.Common.csproj @@ -2,20 +2,24 @@ Common primitives for ASP.NET Connection Handlers and clients - netstandard2.0 + netstandard2.0;netcoreapp3.0 Microsoft.AspNetCore.Http.Connections true true + $(NoWarn);3021 - + - + + All + + diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs index 49b3cc4336..d473f5c93c 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs @@ -5,180 +5,186 @@ using System; using System.Buffers; using System.Collections.Generic; using System.IO; +using System.Text; +using System.Text.Json; using Microsoft.AspNetCore.Internal; -using Newtonsoft.Json; namespace Microsoft.AspNetCore.Http.Connections { public static class NegotiateProtocol { private const string ConnectionIdPropertyName = "connectionId"; + private static readonly byte[] ConnectionIdPropertyNameBytes = Encoding.UTF8.GetBytes(ConnectionIdPropertyName); private const string UrlPropertyName = "url"; + private static readonly byte[] UrlPropertyNameBytes = Encoding.UTF8.GetBytes(UrlPropertyName); private const string AccessTokenPropertyName = "accessToken"; + private static readonly byte[] AccessTokenPropertyNameBytes = Encoding.UTF8.GetBytes(AccessTokenPropertyName); private const string AvailableTransportsPropertyName = "availableTransports"; + private static readonly byte[] AvailableTransportsPropertyNameBytes = Encoding.UTF8.GetBytes(AvailableTransportsPropertyName); private const string TransportPropertyName = "transport"; + private static readonly byte[] TransportPropertyNameBytes = Encoding.UTF8.GetBytes(TransportPropertyName); private const string TransferFormatsPropertyName = "transferFormats"; + private static readonly byte[] TransferFormatsPropertyNameBytes = Encoding.UTF8.GetBytes(TransferFormatsPropertyName); private const string ErrorPropertyName = "error"; + private static readonly byte[] ErrorPropertyNameBytes = Encoding.UTF8.GetBytes(ErrorPropertyName); + // Used to detect ASP.NET SignalR Server connection attempt private const string ProtocolVersionPropertyName = "ProtocolVersion"; + private static readonly byte[] ProtocolVersionPropertyNameBytes = Encoding.UTF8.GetBytes(ProtocolVersionPropertyName); public static void WriteResponse(NegotiationResponse response, IBufferWriter output) { - var textWriter = Utf8BufferTextWriter.Get(output); - try + var writer = new Utf8JsonWriter(output, new JsonWriterState(new JsonWriterOptions() { SkipValidation = true })); + writer.WriteStartObject(); + + if (!string.IsNullOrEmpty(response.Url)) { - using (var jsonWriter = JsonUtils.CreateJsonTextWriter(textWriter)) + writer.WriteString(UrlPropertyNameBytes, response.Url, escape: false); + } + + if (!string.IsNullOrEmpty(response.AccessToken)) + { + writer.WriteString(AccessTokenPropertyNameBytes, response.AccessToken, escape: false); + } + + if (!string.IsNullOrEmpty(response.ConnectionId)) + { + writer.WriteString(ConnectionIdPropertyNameBytes, response.ConnectionId, escape: false); + } + + writer.WriteStartArray(AvailableTransportsPropertyNameBytes, escape: false); + + if (response.AvailableTransports != null) + { + foreach (var availableTransport in response.AvailableTransports) { - jsonWriter.WriteStartObject(); - - if (!string.IsNullOrEmpty(response.Url)) + writer.WriteStartObject(); + if (availableTransport.Transport != null) { - jsonWriter.WritePropertyName(UrlPropertyName); - jsonWriter.WriteValue(response.Url); + writer.WriteString(TransportPropertyNameBytes, availableTransport.Transport, escape: false); } - - if (!string.IsNullOrEmpty(response.AccessToken)) + else { - jsonWriter.WritePropertyName(AccessTokenPropertyName); - jsonWriter.WriteValue(response.AccessToken); + // Might be able to remove this after https://github.com/dotnet/corefx/issues/34632 is resolved + writer.WriteNull(TransportPropertyNameBytes, escape: false); } + writer.WriteStartArray(TransferFormatsPropertyNameBytes, escape: false); - if (!string.IsNullOrEmpty(response.ConnectionId)) + if (availableTransport.TransferFormats != null) { - jsonWriter.WritePropertyName(ConnectionIdPropertyName); - jsonWriter.WriteValue(response.ConnectionId); - } - - jsonWriter.WritePropertyName(AvailableTransportsPropertyName); - jsonWriter.WriteStartArray(); - - if (response.AvailableTransports != null) - { - foreach (var availableTransport in response.AvailableTransports) + foreach (var transferFormat in availableTransport.TransferFormats) { - jsonWriter.WriteStartObject(); - jsonWriter.WritePropertyName(TransportPropertyName); - jsonWriter.WriteValue(availableTransport.Transport); - jsonWriter.WritePropertyName(TransferFormatsPropertyName); - jsonWriter.WriteStartArray(); - - if (availableTransport.TransferFormats != null) - { - foreach (var transferFormat in availableTransport.TransferFormats) - { - jsonWriter.WriteValue(transferFormat); - } - } - - jsonWriter.WriteEndArray(); - jsonWriter.WriteEndObject(); + writer.WriteStringValue(transferFormat, escape: false); } } - jsonWriter.WriteEndArray(); - jsonWriter.WriteEndObject(); - - jsonWriter.Flush(); + writer.WriteEndArray(); + writer.WriteEndObject(); } } - finally - { - Utf8BufferTextWriter.Return(textWriter); - } + + writer.WriteEndArray(); + writer.WriteEndObject(); + + writer.Flush(isFinalBlock: true); } - public static NegotiationResponse ParseResponse(Stream content) + public static NegotiationResponse ParseResponse(ReadOnlySpan content) { try { - using (var reader = JsonUtils.CreateJsonTextReader(new StreamReader(content))) + var reader = new Utf8JsonReader(content, isFinalBlock: true, state: default); + + reader.CheckRead(); + reader.EnsureObjectStart(); + + string connectionId = null; + string url = null; + string accessToken = null; + List availableTransports = null; + string error = null; + + var completed = false; + while (!completed && reader.CheckRead()) { - JsonUtils.CheckRead(reader); - JsonUtils.EnsureObjectStart(reader); - - string connectionId = null; - string url = null; - string accessToken = null; - List availableTransports = null; - string error = null; - - var completed = false; - while (!completed && JsonUtils.CheckRead(reader)) + switch (reader.TokenType) { - switch (reader.TokenType) - { - case JsonToken.PropertyName: - var memberName = reader.Value.ToString(); + case JsonTokenType.PropertyName: + var memberName = reader.ValueSpan; - switch (memberName) + if (memberName.SequenceEqual(UrlPropertyNameBytes)) + { + url = reader.ReadAsString(UrlPropertyNameBytes); + } + else if (memberName.SequenceEqual(AccessTokenPropertyNameBytes)) + { + accessToken = reader.ReadAsString(AccessTokenPropertyNameBytes); + } + else if (memberName.SequenceEqual(ConnectionIdPropertyNameBytes)) + { + connectionId = reader.ReadAsString(ConnectionIdPropertyNameBytes); + } + else if (memberName.SequenceEqual(AvailableTransportsPropertyNameBytes)) + { + reader.CheckRead(); + reader.EnsureArrayStart(); + + availableTransports = new List(); + while (reader.CheckRead()) { - case UrlPropertyName: - url = JsonUtils.ReadAsString(reader, UrlPropertyName); - break; - case AccessTokenPropertyName: - accessToken = JsonUtils.ReadAsString(reader, AccessTokenPropertyName); - break; - case ConnectionIdPropertyName: - connectionId = JsonUtils.ReadAsString(reader, ConnectionIdPropertyName); - break; - case AvailableTransportsPropertyName: - JsonUtils.CheckRead(reader); - JsonUtils.EnsureArrayStart(reader); - - availableTransports = new List(); - while (JsonUtils.CheckRead(reader)) - { - if (reader.TokenType == JsonToken.StartObject) - { - availableTransports.Add(ParseAvailableTransport(reader)); - } - else if (reader.TokenType == JsonToken.EndArray) - { - break; - } - } - break; - case ErrorPropertyName: - error = JsonUtils.ReadAsString(reader, ErrorPropertyName); - break; - case ProtocolVersionPropertyName: - throw new InvalidOperationException("Detected a connection attempt to an ASP.NET SignalR Server. This client only supports connecting to an ASP.NET Core SignalR Server. See https://aka.ms/signalr-core-differences for details."); - default: - reader.Skip(); + if (reader.TokenType == JsonTokenType.StartObject) + { + availableTransports.Add(ParseAvailableTransport(ref reader)); + } + else if (reader.TokenType == JsonTokenType.EndArray) + { break; + } } - break; - case JsonToken.EndObject: - completed = true; - break; - default: - throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading negotiation response JSON."); - } + } + else if (memberName.SequenceEqual(ErrorPropertyNameBytes)) + { + error = reader.ReadAsString(ErrorPropertyNameBytes); + } + else if (memberName.SequenceEqual(ProtocolVersionPropertyNameBytes)) + { + throw new InvalidOperationException("Detected a connection attempt to an ASP.NET SignalR Server. This client only supports connecting to an ASP.NET Core SignalR Server. See https://aka.ms/signalr-core-differences for details."); + } + else + { + reader.Skip(); + } + break; + case JsonTokenType.EndObject: + completed = true; + break; + default: + throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading negotiation response JSON."); } - - if (url == null && error == null) - { - // if url isn't specified or there isn't an error, connectionId and available transports are required - if (connectionId == null) - { - throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'."); - } - - if (availableTransports == null) - { - throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'."); - } - } - - return new NegotiationResponse - { - ConnectionId = connectionId, - Url = url, - AccessToken = accessToken, - AvailableTransports = availableTransports, - Error = error, - }; } + + if (url == null && error == null) + { + // if url isn't specified or there isn't an error, connectionId and available transports are required + if (connectionId == null) + { + throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'."); + } + + if (availableTransports == null) + { + throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'."); + } + } + + return new NegotiationResponse + { + ConnectionId = connectionId, + Url = url, + AccessToken = accessToken, + AvailableTransports = availableTransports, + Error = error, + }; } catch (Exception ex) { @@ -186,49 +192,60 @@ namespace Microsoft.AspNetCore.Http.Connections } } - private static AvailableTransport ParseAvailableTransport(JsonTextReader reader) + /// + /// + /// This method is obsolete and will be removed in a future version. + /// The recommended alternative is . + /// + /// + [Obsolete("This method is obsolete and will be removed in a future version. The recommended alternative is ParseResponse(ReadOnlySpan{byte}).")] + public static NegotiationResponse ParseResponse(Stream content) => + throw new NotSupportedException("This method is obsolete and will be removed in a future version. The recommended alternative is ParseResponse(ReadOnlySpan{byte})."); + + private static AvailableTransport ParseAvailableTransport(ref Utf8JsonReader reader) { var availableTransport = new AvailableTransport(); - while (JsonUtils.CheckRead(reader)) + while (reader.CheckRead()) { switch (reader.TokenType) { - case JsonToken.PropertyName: - var memberName = reader.Value.ToString(); + case JsonTokenType.PropertyName: + var memberName = reader.ValueSpan; - switch (memberName) + if (memberName.SequenceEqual(TransportPropertyNameBytes)) { - case TransportPropertyName: - availableTransport.Transport = JsonUtils.ReadAsString(reader, TransportPropertyName); - break; - case TransferFormatsPropertyName: - JsonUtils.CheckRead(reader); - JsonUtils.EnsureArrayStart(reader); + availableTransport.Transport = reader.ReadAsString(TransportPropertyNameBytes); + } + else if (memberName.SequenceEqual(TransferFormatsPropertyNameBytes)) + { + reader.CheckRead(); + reader.EnsureArrayStart(); - var completed = false; - availableTransport.TransferFormats = new List(); - while (!completed && JsonUtils.CheckRead(reader)) + var completed = false; + + availableTransport.TransferFormats = new List(); + while (!completed && reader.CheckRead()) + { + switch (reader.TokenType) { - switch (reader.TokenType) - { - case JsonToken.String: - availableTransport.TransferFormats.Add(reader.Value.ToString()); - break; - case JsonToken.EndArray: - completed = true; - break; - default: - throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading transfer formats JSON."); - } + case JsonTokenType.String: + availableTransport.TransferFormats.Add(reader.GetString()); + break; + case JsonTokenType.EndArray: + completed = true; + break; + default: + throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading transfer formats JSON."); } - break; - default: - reader.Skip(); - break; + } + } + else + { + reader.Skip(); } break; - case JsonToken.EndObject: + case JsonTokenType.EndObject: if (availableTransport.Transport == null) { throw new InvalidDataException($"Missing required property '{TransportPropertyName}'."); diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs index 02293bdc40..a01d2e637c 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; @@ -13,4 +13,4 @@ namespace Microsoft.AspNetCore.Http.Connections public IList AvailableTransports { get; set; } public string Error { get; set; } } -} \ No newline at end of file +} diff --git a/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs b/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs index 39d21610e7..e92d3c3b42 100644 --- a/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs +++ b/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs @@ -21,8 +21,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports, string url, string accessToken) { var responseData = Encoding.UTF8.GetBytes(json); - var ms = new MemoryStream(responseData); - var response = NegotiateProtocol.ParseResponse(ms); + var response = NegotiateProtocol.ParseResponse(responseData); Assert.Equal(connectionId, response.ConnectionId); Assert.Equal(availableTransports?.Length, response.AvailableTransports?.Count); @@ -48,9 +47,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests public void ParsingNegotiateResponseMessageThrowsForInvalid(string payload, string expectedMessage) { var responseData = Encoding.UTF8.GetBytes(payload); - var ms = new MemoryStream(responseData); - var exception = Assert.Throws(() => NegotiateProtocol.ParseResponse(ms)); + var exception = Assert.Throws(() => NegotiateProtocol.ParseResponse(responseData)); Assert.Equal(expectedMessage, exception.InnerException.Message); } @@ -69,9 +67,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests "\"LongPollDelay\":0.0}"; var responseData = Encoding.UTF8.GetBytes(payload); - var ms = new MemoryStream(responseData); - var exception = Assert.Throws(() => NegotiateProtocol.ParseResponse(ms)); + var exception = Assert.Throws(() => NegotiateProtocol.ParseResponse(responseData)); Assert.Equal("Detected a connection attempt to an ASP.NET SignalR Server. This client only supports connecting to an ASP.NET Core SignalR Server. See https://aka.ms/signalr-core-differences for details.", exception.InnerException.Message); } diff --git a/src/SignalR/common/Shared/SystemTextJsonExtensions.cs b/src/SignalR/common/Shared/SystemTextJsonExtensions.cs new file mode 100644 index 0000000000..68ff75b50e --- /dev/null +++ b/src/SignalR/common/Shared/SystemTextJsonExtensions.cs @@ -0,0 +1,102 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; +using System.Text; +using System.Text.Json; + +namespace Microsoft.AspNetCore.Internal +{ + internal static class SystemTextJsonExtensions + { + public static bool CheckRead(this ref Utf8JsonReader reader) + { + if (!reader.Read()) + { + throw new InvalidDataException("Unexpected end when reading JSON."); + } + + return true; + } + + public static void EnsureObjectStart(this ref Utf8JsonReader reader) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new InvalidDataException($"Unexpected JSON Token Type '{GetTokenString(reader.TokenType)}'. Expected a JSON Object."); + } + } + + public static string GetTokenString(JsonTokenType tokenType) + { + switch (tokenType) + { + case JsonTokenType.None: + break; + case JsonTokenType.StartObject: + return "Object"; + case JsonTokenType.StartArray: + return "Array"; + case JsonTokenType.PropertyName: + return "Property"; + default: + break; + } + return tokenType.ToString(); + } + + public static void EnsureArrayStart(this ref Utf8JsonReader reader) + { + if (reader.TokenType != JsonTokenType.StartArray) + { + throw new InvalidDataException($"Unexpected JSON Token Type '{GetTokenString(reader.TokenType)}'. Expected a JSON Array."); + } + } + + // Remove after https://github.com/dotnet/corefx/issues/33295 is done + public static void Skip(this ref Utf8JsonReader reader) + { + if (reader.TokenType == JsonTokenType.PropertyName) + { + reader.Read(); + } + + if (reader.TokenType == JsonTokenType.StartObject || reader.TokenType == JsonTokenType.StartArray) + { + int depth = reader.CurrentDepth; + while (reader.Read() && depth < reader.CurrentDepth) + { + } + } + } + + public static string ReadAsString(this ref Utf8JsonReader reader, byte[] propertyName) + { + reader.Read(); + + if (reader.TokenType != JsonTokenType.String) + { + throw new InvalidDataException($"Expected '{Encoding.UTF8.GetString(propertyName)}' to be of type {JsonTokenType.String}."); + } + + return reader.GetString(); + } + + public static int? ReadAsInt32(this ref Utf8JsonReader reader, byte[] propertyName) + { + reader.Read(); + + if (reader.TokenType == JsonTokenType.Null) + { + return null; + } + + if (reader.TokenType != JsonTokenType.Number) + { + throw new InvalidDataException($"Expected '{Encoding.UTF8.GetString(propertyName)}' to be of type {JsonTokenType.Number}."); + } + + return reader.GetInt32(); + } + } +} diff --git a/src/SignalR/common/Shared/TextMessageParser.cs b/src/SignalR/common/Shared/TextMessageParser.cs index 3d7d233e59..026d2a297e 100644 --- a/src/SignalR/common/Shared/TextMessageParser.cs +++ b/src/SignalR/common/Shared/TextMessageParser.cs @@ -3,12 +3,38 @@ using System; using System.Buffers; +using System.Runtime.CompilerServices; namespace Microsoft.AspNetCore.Internal { internal static class TextMessageParser { + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static bool TryParseMessage(ref ReadOnlySequence buffer, out ReadOnlySequence payload) + { + if (buffer.IsSingleSegment) + { + var span = buffer.First.Span; + var index = span.IndexOf(TextMessageFormatter.RecordSeparator); + if (index == -1) + { + payload = default; + return false; + } + + payload = buffer.Slice(0, index); + + buffer = buffer.Slice(index + 1); + + return true; + } + else + { + return TryParseMessageMultiSegment(ref buffer, out payload); + } + } + + private static bool TryParseMessageMultiSegment(ref ReadOnlySequence buffer, out ReadOnlySequence payload) { var position = buffer.PositionOf(TextMessageFormatter.RecordSeparator); if (position == null) diff --git a/src/SignalR/common/SignalR.Common/src/Microsoft.AspNetCore.SignalR.Common.csproj b/src/SignalR/common/SignalR.Common/src/Microsoft.AspNetCore.SignalR.Common.csproj index b369977dc4..05f40dbebb 100644 --- a/src/SignalR/common/SignalR.Common/src/Microsoft.AspNetCore.SignalR.Common.csproj +++ b/src/SignalR/common/SignalR.Common/src/Microsoft.AspNetCore.SignalR.Common.csproj @@ -6,10 +6,11 @@ Microsoft.AspNetCore.SignalR true true + $(NoWarn);3021 - + @@ -20,8 +21,11 @@ - + + All + + diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/HandshakeProtocol.cs b/src/SignalR/common/SignalR.Common/src/Protocol/HandshakeProtocol.cs index ed1965daed..0b4cd7cefb 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/HandshakeProtocol.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/HandshakeProtocol.cs @@ -6,9 +6,9 @@ using System.Buffers; using System.Collections.Concurrent; using System.IO; using System.Text; +using System.Text.Json; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Internal; -using Newtonsoft.Json; namespace Microsoft.AspNetCore.SignalR.Protocol { @@ -18,17 +18,22 @@ namespace Microsoft.AspNetCore.SignalR.Protocol public static class HandshakeProtocol { private const string ProtocolPropertyName = "protocol"; + private static readonly byte[] ProtocolPropertyNameBytes = Encoding.UTF8.GetBytes(ProtocolPropertyName); private const string ProtocolVersionPropertyName = "version"; + private static readonly byte[] ProtocolVersionPropertyNameBytes = Encoding.UTF8.GetBytes(ProtocolVersionPropertyName); private const string MinorVersionPropertyName = "minorVersion"; + private static readonly byte[] MinorVersionPropertyNameBytes = Encoding.UTF8.GetBytes(MinorVersionPropertyName); private const string ErrorPropertyName = "error"; + private static readonly byte[] ErrorPropertyNameBytes = Encoding.UTF8.GetBytes(ErrorPropertyName); private const string TypePropertyName = "type"; + private static readonly byte[] TypePropertyNameBytes = Encoding.UTF8.GetBytes(TypePropertyName); private static ConcurrentDictionary> _messageCache = new ConcurrentDictionary>(); public static ReadOnlySpan GetSuccessfulHandshake(IHubProtocol protocol) { ReadOnlyMemory result; - if(!_messageCache.TryGetValue(protocol, out result)) + if (!_messageCache.TryGetValue(protocol, out result)) { var memoryBufferWriter = MemoryBufferWriter.Get(); try @@ -53,24 +58,13 @@ namespace Microsoft.AspNetCore.SignalR.Protocol /// The output writer. public static void WriteRequestMessage(HandshakeRequestMessage requestMessage, IBufferWriter output) { - var textWriter = Utf8BufferTextWriter.Get(output); - try - { - using (var writer = JsonUtils.CreateJsonTextWriter(textWriter)) - { - writer.WriteStartObject(); - writer.WritePropertyName(ProtocolPropertyName); - writer.WriteValue(requestMessage.Protocol); - writer.WritePropertyName(ProtocolVersionPropertyName); - writer.WriteValue(requestMessage.Version); - writer.WriteEndObject(); - writer.Flush(); - } - } - finally - { - Utf8BufferTextWriter.Return(textWriter); - } + var writer = new Utf8JsonWriter(output, new JsonWriterState(new JsonWriterOptions() { SkipValidation = true })); + + writer.WriteStartObject(); + writer.WriteString(ProtocolPropertyNameBytes, requestMessage.Protocol, escape: false); + writer.WriteNumber(ProtocolVersionPropertyNameBytes, requestMessage.Version, escape: false); + writer.WriteEndObject(); + writer.Flush(isFinalBlock: true); TextMessageFormatter.WriteRecordSeparator(output); } @@ -82,30 +76,19 @@ namespace Microsoft.AspNetCore.SignalR.Protocol /// The output writer. public static void WriteResponseMessage(HandshakeResponseMessage responseMessage, IBufferWriter output) { - var textWriter = Utf8BufferTextWriter.Get(output); - try - { - using (var writer = JsonUtils.CreateJsonTextWriter(textWriter)) - { - writer.WriteStartObject(); - if (!string.IsNullOrEmpty(responseMessage.Error)) - { - writer.WritePropertyName(ErrorPropertyName); - writer.WriteValue(responseMessage.Error); - } + var writer = new Utf8JsonWriter(output, new JsonWriterState(new JsonWriterOptions() { SkipValidation = true })); - writer.WritePropertyName(MinorVersionPropertyName); - writer.WriteValue(responseMessage.MinorVersion); - - writer.WriteEndObject(); - writer.Flush(); - } - } - finally + writer.WriteStartObject(); + if (!string.IsNullOrEmpty(responseMessage.Error)) { - Utf8BufferTextWriter.Return(textWriter); + writer.WriteString(ErrorPropertyNameBytes, responseMessage.Error); } + writer.WriteNumber(MinorVersionPropertyNameBytes, responseMessage.MinorVersion, escape: false); + + writer.WriteEndObject(); + writer.Flush(isFinalBlock: true); + TextMessageFormatter.WriteRecordSeparator(output); } @@ -123,59 +106,51 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return false; } - var textReader = Utf8BufferTextReader.Get(payload); + var reader = new Utf8JsonReader(in payload, isFinalBlock: true, state: default); - try + reader.CheckRead(); + reader.EnsureObjectStart(); + + int? minorVersion = null; + string error = null; + + while (reader.CheckRead()) { - using (var reader = JsonUtils.CreateJsonTextReader(textReader)) + if (reader.TokenType == JsonTokenType.PropertyName) { - JsonUtils.CheckRead(reader); - JsonUtils.EnsureObjectStart(reader); + var memberName = reader.HasValueSequence ? reader.ValueSequence.ToArray() : reader.ValueSpan; - int? minorVersion = null; - string error = null; - - var completed = false; - while (!completed && JsonUtils.CheckRead(reader)) + if (memberName.SequenceEqual(TypePropertyNameBytes)) { - switch (reader.TokenType) - { - case JsonToken.PropertyName: - var memberName = reader.Value.ToString(); - - switch (memberName) - { - case TypePropertyName: - // a handshake response does not have a type - // check the incoming message was not any other type of message - throw new InvalidDataException("Expected a handshake response from the server."); - case ErrorPropertyName: - error = JsonUtils.ReadAsString(reader, ErrorPropertyName); - break; - case MinorVersionPropertyName: - minorVersion = JsonUtils.ReadAsInt32(reader, MinorVersionPropertyName); - break; - default: - reader.Skip(); - break; - } - break; - case JsonToken.EndObject: - completed = true; - break; - default: - throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading handshake response JSON."); - } - }; - - responseMessage = new HandshakeResponseMessage(minorVersion, error); - return true; + // a handshake response does not have a type + // check the incoming message was not any other type of message + throw new InvalidDataException("Expected a handshake response from the server."); + } + else if (memberName.SequenceEqual(ErrorPropertyNameBytes)) + { + error = reader.ReadAsString(ErrorPropertyNameBytes); + } + else if (memberName.SequenceEqual(MinorVersionPropertyNameBytes)) + { + minorVersion = reader.ReadAsInt32(MinorVersionPropertyNameBytes); + } + else + { + reader.Skip(); + } } - } - finally - { - Utf8BufferTextReader.Return(textReader); - } + else if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + else + { + throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading handshake response JSON."); + } + }; + + responseMessage = new HandshakeResponseMessage(minorVersion, error); + return true; } /// @@ -192,62 +167,53 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return false; } - var textReader = Utf8BufferTextReader.Get(payload); - try + var reader = new Utf8JsonReader(in payload, isFinalBlock: true, state: default); + + reader.CheckRead(); + reader.EnsureObjectStart(); + + string protocol = null; + int? protocolVersion = null; + + while (reader.CheckRead()) { - using (var reader = JsonUtils.CreateJsonTextReader(textReader)) + if (reader.TokenType == JsonTokenType.PropertyName) { - JsonUtils.CheckRead(reader); - JsonUtils.EnsureObjectStart(reader); + var memberName = reader.HasValueSequence ? reader.ValueSequence.ToArray() : reader.ValueSpan; - string protocol = null; - int? protocolVersion = null; - - var completed = false; - while (!completed && JsonUtils.CheckRead(reader)) + if (memberName.SequenceEqual(ProtocolPropertyNameBytes)) { - switch (reader.TokenType) - { - case JsonToken.PropertyName: - var memberName = reader.Value.ToString(); - - switch (memberName) - { - case ProtocolPropertyName: - protocol = JsonUtils.ReadAsString(reader, ProtocolPropertyName); - break; - case ProtocolVersionPropertyName: - protocolVersion = JsonUtils.ReadAsInt32(reader, ProtocolVersionPropertyName); - break; - default: - reader.Skip(); - break; - } - break; - case JsonToken.EndObject: - completed = true; - break; - default: - throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading handshake request JSON. Message content: {GetPayloadAsString()}"); - } + protocol = reader.ReadAsString(ProtocolPropertyNameBytes); } - - if (protocol == null) + else if (memberName.SequenceEqual(ProtocolVersionPropertyNameBytes)) { - throw new InvalidDataException($"Missing required property '{ProtocolPropertyName}'. Message content: {GetPayloadAsString()}"); + protocolVersion = reader.ReadAsInt32(ProtocolVersionPropertyNameBytes); } - if (protocolVersion == null) + else { - throw new InvalidDataException($"Missing required property '{ProtocolVersionPropertyName}'. Message content: {GetPayloadAsString()}"); + reader.Skip(); } - - requestMessage = new HandshakeRequestMessage(protocol, protocolVersion.Value); + } + else if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + else + { + throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading handshake request JSON. Message content: {GetPayloadAsString()}"); } } - finally + + if (protocol == null) { - Utf8BufferTextReader.Return(textReader); + throw new InvalidDataException($"Missing required property '{ProtocolPropertyName}'. Message content: {GetPayloadAsString()}"); } + if (protocolVersion == null) + { + throw new InvalidDataException($"Missing required property '{ProtocolVersionPropertyName}'. Message content: {GetPayloadAsString()}"); + } + + requestMessage = new HandshakeRequestMessage(protocol, protocolVersion.Value); // For error messages, we want to print the payload as text string GetPayloadAsString() diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/HandshakeProtocolTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/HandshakeProtocolTests.cs index b2c9aaa65d..58bc7b2b70 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/HandshakeProtocolTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/HandshakeProtocolTests.cs @@ -25,6 +25,17 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol Assert.Equal(version, deserializedMessage.Version); } + [Fact] + public void ParsingHandshakeRequestMessageSuccessForValidMessageWithMultipleSegments() + { + var message = ReadOnlySequenceFactory.SegmentPerByteFactory.CreateWithContent("{\"protocol\":\"json\",\"version\":1}\u001e"); + + Assert.True(HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage)); + + Assert.Equal("json", deserializedMessage.Protocol); + Assert.Equal(1, deserializedMessage.Version); + } + [Theory] [InlineData("{\"error\":\"dummy\"}\u001e", "dummy")] [InlineData("{\"error\":\"\"}\u001e", "")] @@ -38,6 +49,16 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol Assert.Equal(error, response.Error); } + [Fact] + public void ParsingHandshakeResponseMessageSuccessForValidMessageWithMultipleSegments() + { + var message = ReadOnlySequenceFactory.SegmentPerByteFactory.CreateWithContent("{\"error\":\"dummy\"}\u001e"); + + Assert.True(HandshakeProtocol.TryParseResponseMessage(ref message, out var response)); + + Assert.Equal("dummy", response.Error); + } + [Theory] [InlineData("{\"error\":\"\",\"minorVersion\":34}\u001e", 34)] [InlineData("{\"error\":\"flump flump flump\",\"minorVersion\":112}\u001e", 112)] @@ -58,7 +79,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol } [Theory] - [InlineData("42\u001e", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")] + [InlineData("42\u001e", "Unexpected JSON Token Type 'Number'. Expected a JSON Object.")] [InlineData("\"42\"\u001e", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")] [InlineData("null\u001e", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")] [InlineData("{}\u001e", "Missing required property 'protocol'. Message content: {}")] @@ -66,7 +87,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [InlineData("{\"protocol\":\"json\"}\u001e", "Missing required property 'version'. Message content: {\"protocol\":\"json\"}")] [InlineData("{\"version\":1}\u001e", "Missing required property 'protocol'. Message content: {\"version\":1}")] [InlineData("{\"type\":4,\"invocationId\":\"42\",\"target\":\"foo\",\"arguments\":{}}\u001e", "Missing required property 'protocol'. Message content: {\"type\":4,\"invocationId\":\"42\",\"target\":\"foo\",\"arguments\":{}}")] - [InlineData("{\"version\":\"123\"}\u001e", "Expected 'version' to be of type Integer.")] + [InlineData("{\"version\":\"123\"}\u001e", "Expected 'version' to be of type Number.")] [InlineData("{\"protocol\":null,\"version\":123}\u001e", "Expected 'protocol' to be of type String.")] public void ParsingHandshakeRequestMessageThrowsForInvalidMessages(string payload, string expectedMessage) { @@ -79,7 +100,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol } [Theory] - [InlineData("42\u001e", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")] + [InlineData("42\u001e", "Unexpected JSON Token Type 'Number'. Expected a JSON Object.")] [InlineData("\"42\"\u001e", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")] [InlineData("null\u001e", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")] [InlineData("[]\u001e", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")] diff --git a/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj b/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj index 50d78b970f..39b8b2ad4b 100644 --- a/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj +++ b/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj @@ -7,6 +7,7 @@ + diff --git a/src/SignalR/perf/Microbenchmarks/HandshakeProtocolBenchmark.cs b/src/SignalR/perf/Microbenchmarks/HandshakeProtocolBenchmark.cs index 2a36dc648e..c14b6125d0 100644 --- a/src/SignalR/perf/Microbenchmarks/HandshakeProtocolBenchmark.cs +++ b/src/SignalR/perf/Microbenchmarks/HandshakeProtocolBenchmark.cs @@ -16,6 +16,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks ReadOnlySequence _requestMessage2; ReadOnlySequence _requestMessage3; ReadOnlySequence _requestMessage4; + ReadOnlySequence _requestMessage5; ReadOnlySequence _responseMessage1; ReadOnlySequence _responseMessage2; @@ -31,6 +32,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks _requestMessage2 = new ReadOnlySequence(Encoding.UTF8.GetBytes("{\"protocol\":\"\",\"version\":10}\u001e")); _requestMessage3 = new ReadOnlySequence(Encoding.UTF8.GetBytes("{\"protocol\":\"\",\"version\":10,\"unknown\":null}\u001e")); _requestMessage4 = new ReadOnlySequence(Encoding.UTF8.GetBytes("42")); + _requestMessage5 = ReadOnlySequenceFactory.CreateSegments(Encoding.UTF8.GetBytes("{\"protocol\":\"dummy\",\"ver"), Encoding.UTF8.GetBytes("sion\":1}\u001e")); _responseMessage1 = new ReadOnlySequence(Encoding.UTF8.GetBytes("{\"error\":\"dummy\"}\u001e")); _responseMessage2 = new ReadOnlySequence(Encoding.UTF8.GetBytes("{\"error\":\"\"}\u001e")); @@ -85,43 +87,112 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks } [Benchmark] - public void ParsingHandshakeRequestMessage_ValidMessage1() - => HandshakeProtocol.TryParseRequestMessage(ref _requestMessage1, out var deserializedMessage); + public void ParsingHandshakeRequestMessage_ValidMessage1() { + var message = _requestMessage1; + if (!HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage)) + { + throw new Exception(); + } + } [Benchmark] public void ParsingHandshakeRequestMessage_ValidMessage2() - => HandshakeProtocol.TryParseRequestMessage(ref _requestMessage2, out var deserializedMessage); + { + var message = _requestMessage2; + if (!HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage)) + { + throw new Exception(); + } + } [Benchmark] public void ParsingHandshakeRequestMessage_ValidMessage3() - => HandshakeProtocol.TryParseRequestMessage(ref _requestMessage3, out var deserializedMessage); + { + var message = _requestMessage3; + if (!HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage)) + { + throw new Exception(); + } + } [Benchmark] public void ParsingHandshakeRequestMessage_NotComplete1() - => HandshakeProtocol.TryParseRequestMessage(ref _requestMessage4, out _); + { + var message = _requestMessage4; + if (!HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage)) + { + throw new Exception(); + } + } + + [Benchmark] + public void ParsingHandshakeRequestMessage_ValidMessageSegments() + { + var message = _requestMessage5; + if (!HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage)) + { + throw new Exception(); + } + } [Benchmark] public void ParsingHandshakeResponseMessage_ValidMessages1() - => HandshakeProtocol.TryParseResponseMessage(ref _responseMessage1, out var response); + { + var message = _responseMessage1; + if (!HandshakeProtocol.TryParseResponseMessage(ref message, out var deserializedMessage)) + { + throw new Exception(); + } + } [Benchmark] public void ParsingHandshakeResponseMessage_ValidMessages2() - => HandshakeProtocol.TryParseResponseMessage(ref _responseMessage2, out var response); + { + var message = _responseMessage2; + if (!HandshakeProtocol.TryParseResponseMessage(ref message, out var deserializedMessage)) + { + throw new Exception(); + } + } [Benchmark] public void ParsingHandshakeResponseMessage_ValidMessages3() - => HandshakeProtocol.TryParseResponseMessage(ref _responseMessage3, out var response); + { + var message = _responseMessage3; + if (!HandshakeProtocol.TryParseResponseMessage(ref message, out var deserializedMessage)) + { + throw new Exception(); + } + } [Benchmark] public void ParsingHandshakeResponseMessage_ValidMessages4() - => HandshakeProtocol.TryParseResponseMessage(ref _responseMessage4, out var response); + { + var message = _responseMessage4; + if (!HandshakeProtocol.TryParseResponseMessage(ref message, out var deserializedMessage)) + { + throw new Exception(); + } + } [Benchmark] public void ParsingHandshakeResponseMessage_GivesMinorVersion1() - => HandshakeProtocol.TryParseResponseMessage(ref _responseMessage5, out var response); + { + var message = _responseMessage5; + if (!HandshakeProtocol.TryParseResponseMessage(ref message, out var deserializedMessage)) + { + throw new Exception(); + } + } [Benchmark] public void ParsingHandshakeResponseMessage_GivesMinorVersion2() - => HandshakeProtocol.TryParseResponseMessage(ref _responseMessage6, out var response); + { + var message = _responseMessage6; + if (!HandshakeProtocol.TryParseResponseMessage(ref message, out var deserializedMessage)) + { + throw new Exception(); + } + } } } \ No newline at end of file diff --git a/src/SignalR/perf/Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj b/src/SignalR/perf/Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj index 534419510c..3c2343fb39 100644 --- a/src/SignalR/perf/Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj +++ b/src/SignalR/perf/Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj @@ -8,6 +8,7 @@ + diff --git a/src/SignalR/perf/Microbenchmarks/NegotiateProtocolBenchmark.cs b/src/SignalR/perf/Microbenchmarks/NegotiateProtocolBenchmark.cs index 909b4217d0..c0c67a0742 100644 --- a/src/SignalR/perf/Microbenchmarks/NegotiateProtocolBenchmark.cs +++ b/src/SignalR/perf/Microbenchmarks/NegotiateProtocolBenchmark.cs @@ -70,22 +70,22 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [Benchmark] public void ParsingNegotiateResponseMessageSuccessForValid1() - => NegotiateProtocol.ParseResponse(new MemoryStream(_responseData1)); + => NegotiateProtocol.ParseResponse(_responseData1); [Benchmark] public void ParsingNegotiateResponseMessageSuccessForValid2() - => NegotiateProtocol.ParseResponse(new MemoryStream(_responseData2)); + => NegotiateProtocol.ParseResponse(_responseData2); [Benchmark] public void ParsingNegotiateResponseMessageSuccessForValid3() - => NegotiateProtocol.ParseResponse(new MemoryStream(_responseData3)); + => NegotiateProtocol.ParseResponse(_responseData3); [Benchmark] public void ParsingNegotiateResponseMessageSuccessForValid4() - => NegotiateProtocol.ParseResponse(new MemoryStream(_responseData4)); + => NegotiateProtocol.ParseResponse(_responseData4); [Benchmark] public void ParsingNegotiateResponseMessageSuccessForValid5() - => NegotiateProtocol.ParseResponse(new MemoryStream(_responseData5)); + => NegotiateProtocol.ParseResponse(_responseData5); } } \ No newline at end of file