Use System.Text.Json for Negotiate and Handshake (#6977)

This commit is contained in:
BrennanConroy 2019-01-31 11:38:09 -08:00 committed by GitHub
parent 7d21ee1a5a
commit dbf82dc8c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 535 additions and 326 deletions

View File

@ -39,6 +39,7 @@ and are generated based on the last package release.
<LatestPackageReference Include="Microsoft.AspNetCore.Razor.Language" Version="$(MicrosoftAspNetCoreRazorLanguagePackageVersion)" />
<LatestPackageReference Include="Microsoft.AspNetCore.Testing" Version="$(MicrosoftAspNetCoreTestingPackageVersion)" />
<LatestPackageReference Include="Microsoft.Azure.KeyVault" Version="$(MicrosoftAzureKeyVaultPackageVersion)" />
<LatestPackageReference Include="Microsoft.Bcl.Json.Sources" Version="$(MicrosoftBclJsonSourcesPacakgeVersion)" />
<LatestPackageReference Include="Microsoft.Build.Framework" Version="$(MicrosoftBuildFrameworkPackageVersion)" />
<LatestPackageReference Include="Microsoft.Build.Utilities.Core" Version="$(MicrosoftBuildUtilitiesCorePackageVersion)" />
<LatestPackageReference Include="Microsoft.CodeAnalysis.Common" Version="$(MicrosoftCodeAnalysisCommonPackageVersion)" />

View File

@ -13,6 +13,7 @@
<MicrosoftNETCoreAppPackageVersion>3.0.0-preview-27324-5</MicrosoftNETCoreAppPackageVersion>
<MicrosoftDotNetPlatformAbstractionsPackageVersion>3.0.0-preview-27324-5</MicrosoftDotNetPlatformAbstractionsPackageVersion>
<!-- Packages from dotnet/corefx -->
<MicrosoftBclJsonSourcesPacakgeVersion>4.6.0-preview.19073.11</MicrosoftBclJsonSourcesPacakgeVersion>
<MicrosoftCSharpPackageVersion>4.6.0-preview.19073.11</MicrosoftCSharpPackageVersion>
<MicrosoftWin32RegistryPackageVersion>4.6.0-preview.19073.11</MicrosoftWin32RegistryPackageVersion>
<SystemComponentModelAnnotationsPackageVersion>4.6.0-preview.19073.11</SystemComponentModelAnnotationsPackageVersion>

View File

@ -439,11 +439,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
using (var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead))
{
response.EnsureSuccessStatusCode();
NegotiationResponse negotiateResponse;
using (var responseStream = await response.Content.ReadAsStreamAsync())
{
negotiateResponse = NegotiateProtocol.ParseResponse(responseStream);
}
var responseBuffer = await response.Content.ReadAsByteArrayAsync();
var negotiateResponse = NegotiateProtocol.ParseResponse(responseBuffer);
if (!string.IsNullOrEmpty(negotiateResponse.Error))
{
throw new Exception(negotiateResponse.Error);

View File

@ -2,20 +2,24 @@
<PropertyGroup>
<Description>Common primitives for ASP.NET Connection Handlers and clients</Description>
<TargetFramework>netstandard2.0</TargetFramework>
<TargetFrameworks>netstandard2.0;netcoreapp3.0</TargetFrameworks>
<RootNamespace>Microsoft.AspNetCore.Http.Connections</RootNamespace>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<IsShippingPackage>true</IsShippingPackage>
<NoWarn>$(NoWarn);3021</NoWarn>
</PropertyGroup>
<ItemGroup>
<Compile Include="$(SignalRSharedSourceRoot)JsonUtils.cs" Link="Internal\JsonUtils.cs" />
<Compile Include="$(SignalRSharedSourceRoot)Utf8BufferTextWriter.cs" Link="Internal\Utf8BufferTextWriter.cs" />
<Compile Include="$(SignalRSharedSourceRoot)SystemTextJsonExtensions.cs" Link="Internal\SystemTextJsonExtensions.cs" />
</ItemGroup>
<ItemGroup>
<Reference Include="Microsoft.AspNetCore.Connections.Abstractions" />
<Reference Include="Newtonsoft.Json" />
<Reference Include="System.Buffers" />
<Reference Include="Microsoft.Bcl.Json.Sources" Condition="'$(TargetFramework)' == 'netstandard2.0'">
<PrivateAssets>All</PrivateAssets>
</Reference>
<Reference Include="System.Runtime.CompilerServices.Unsafe" Condition="'$(TargetFramework)' == 'netstandard2.0'" />
</ItemGroup>
</Project>

View File

@ -5,180 +5,186 @@ using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Text.Json;
using Microsoft.AspNetCore.Internal;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.Http.Connections
{
public static class NegotiateProtocol
{
private const string ConnectionIdPropertyName = "connectionId";
private static readonly byte[] ConnectionIdPropertyNameBytes = Encoding.UTF8.GetBytes(ConnectionIdPropertyName);
private const string UrlPropertyName = "url";
private static readonly byte[] UrlPropertyNameBytes = Encoding.UTF8.GetBytes(UrlPropertyName);
private const string AccessTokenPropertyName = "accessToken";
private static readonly byte[] AccessTokenPropertyNameBytes = Encoding.UTF8.GetBytes(AccessTokenPropertyName);
private const string AvailableTransportsPropertyName = "availableTransports";
private static readonly byte[] AvailableTransportsPropertyNameBytes = Encoding.UTF8.GetBytes(AvailableTransportsPropertyName);
private const string TransportPropertyName = "transport";
private static readonly byte[] TransportPropertyNameBytes = Encoding.UTF8.GetBytes(TransportPropertyName);
private const string TransferFormatsPropertyName = "transferFormats";
private static readonly byte[] TransferFormatsPropertyNameBytes = Encoding.UTF8.GetBytes(TransferFormatsPropertyName);
private const string ErrorPropertyName = "error";
private static readonly byte[] ErrorPropertyNameBytes = Encoding.UTF8.GetBytes(ErrorPropertyName);
// Used to detect ASP.NET SignalR Server connection attempt
private const string ProtocolVersionPropertyName = "ProtocolVersion";
private static readonly byte[] ProtocolVersionPropertyNameBytes = Encoding.UTF8.GetBytes(ProtocolVersionPropertyName);
public static void WriteResponse(NegotiationResponse response, IBufferWriter<byte> output)
{
var textWriter = Utf8BufferTextWriter.Get(output);
try
var writer = new Utf8JsonWriter(output, new JsonWriterState(new JsonWriterOptions() { SkipValidation = true }));
writer.WriteStartObject();
if (!string.IsNullOrEmpty(response.Url))
{
using (var jsonWriter = JsonUtils.CreateJsonTextWriter(textWriter))
writer.WriteString(UrlPropertyNameBytes, response.Url, escape: false);
}
if (!string.IsNullOrEmpty(response.AccessToken))
{
writer.WriteString(AccessTokenPropertyNameBytes, response.AccessToken, escape: false);
}
if (!string.IsNullOrEmpty(response.ConnectionId))
{
writer.WriteString(ConnectionIdPropertyNameBytes, response.ConnectionId, escape: false);
}
writer.WriteStartArray(AvailableTransportsPropertyNameBytes, escape: false);
if (response.AvailableTransports != null)
{
foreach (var availableTransport in response.AvailableTransports)
{
jsonWriter.WriteStartObject();
if (!string.IsNullOrEmpty(response.Url))
writer.WriteStartObject();
if (availableTransport.Transport != null)
{
jsonWriter.WritePropertyName(UrlPropertyName);
jsonWriter.WriteValue(response.Url);
writer.WriteString(TransportPropertyNameBytes, availableTransport.Transport, escape: false);
}
if (!string.IsNullOrEmpty(response.AccessToken))
else
{
jsonWriter.WritePropertyName(AccessTokenPropertyName);
jsonWriter.WriteValue(response.AccessToken);
// Might be able to remove this after https://github.com/dotnet/corefx/issues/34632 is resolved
writer.WriteNull(TransportPropertyNameBytes, escape: false);
}
writer.WriteStartArray(TransferFormatsPropertyNameBytes, escape: false);
if (!string.IsNullOrEmpty(response.ConnectionId))
if (availableTransport.TransferFormats != null)
{
jsonWriter.WritePropertyName(ConnectionIdPropertyName);
jsonWriter.WriteValue(response.ConnectionId);
}
jsonWriter.WritePropertyName(AvailableTransportsPropertyName);
jsonWriter.WriteStartArray();
if (response.AvailableTransports != null)
{
foreach (var availableTransport in response.AvailableTransports)
foreach (var transferFormat in availableTransport.TransferFormats)
{
jsonWriter.WriteStartObject();
jsonWriter.WritePropertyName(TransportPropertyName);
jsonWriter.WriteValue(availableTransport.Transport);
jsonWriter.WritePropertyName(TransferFormatsPropertyName);
jsonWriter.WriteStartArray();
if (availableTransport.TransferFormats != null)
{
foreach (var transferFormat in availableTransport.TransferFormats)
{
jsonWriter.WriteValue(transferFormat);
}
}
jsonWriter.WriteEndArray();
jsonWriter.WriteEndObject();
writer.WriteStringValue(transferFormat, escape: false);
}
}
jsonWriter.WriteEndArray();
jsonWriter.WriteEndObject();
jsonWriter.Flush();
writer.WriteEndArray();
writer.WriteEndObject();
}
}
finally
{
Utf8BufferTextWriter.Return(textWriter);
}
writer.WriteEndArray();
writer.WriteEndObject();
writer.Flush(isFinalBlock: true);
}
public static NegotiationResponse ParseResponse(Stream content)
public static NegotiationResponse ParseResponse(ReadOnlySpan<byte> content)
{
try
{
using (var reader = JsonUtils.CreateJsonTextReader(new StreamReader(content)))
var reader = new Utf8JsonReader(content, isFinalBlock: true, state: default);
reader.CheckRead();
reader.EnsureObjectStart();
string connectionId = null;
string url = null;
string accessToken = null;
List<AvailableTransport> availableTransports = null;
string error = null;
var completed = false;
while (!completed && reader.CheckRead())
{
JsonUtils.CheckRead(reader);
JsonUtils.EnsureObjectStart(reader);
string connectionId = null;
string url = null;
string accessToken = null;
List<AvailableTransport> availableTransports = null;
string error = null;
var completed = false;
while (!completed && JsonUtils.CheckRead(reader))
switch (reader.TokenType)
{
switch (reader.TokenType)
{
case JsonToken.PropertyName:
var memberName = reader.Value.ToString();
case JsonTokenType.PropertyName:
var memberName = reader.ValueSpan;
switch (memberName)
if (memberName.SequenceEqual(UrlPropertyNameBytes))
{
url = reader.ReadAsString(UrlPropertyNameBytes);
}
else if (memberName.SequenceEqual(AccessTokenPropertyNameBytes))
{
accessToken = reader.ReadAsString(AccessTokenPropertyNameBytes);
}
else if (memberName.SequenceEqual(ConnectionIdPropertyNameBytes))
{
connectionId = reader.ReadAsString(ConnectionIdPropertyNameBytes);
}
else if (memberName.SequenceEqual(AvailableTransportsPropertyNameBytes))
{
reader.CheckRead();
reader.EnsureArrayStart();
availableTransports = new List<AvailableTransport>();
while (reader.CheckRead())
{
case UrlPropertyName:
url = JsonUtils.ReadAsString(reader, UrlPropertyName);
break;
case AccessTokenPropertyName:
accessToken = JsonUtils.ReadAsString(reader, AccessTokenPropertyName);
break;
case ConnectionIdPropertyName:
connectionId = JsonUtils.ReadAsString(reader, ConnectionIdPropertyName);
break;
case AvailableTransportsPropertyName:
JsonUtils.CheckRead(reader);
JsonUtils.EnsureArrayStart(reader);
availableTransports = new List<AvailableTransport>();
while (JsonUtils.CheckRead(reader))
{
if (reader.TokenType == JsonToken.StartObject)
{
availableTransports.Add(ParseAvailableTransport(reader));
}
else if (reader.TokenType == JsonToken.EndArray)
{
break;
}
}
break;
case ErrorPropertyName:
error = JsonUtils.ReadAsString(reader, ErrorPropertyName);
break;
case ProtocolVersionPropertyName:
throw new InvalidOperationException("Detected a connection attempt to an ASP.NET SignalR Server. This client only supports connecting to an ASP.NET Core SignalR Server. See https://aka.ms/signalr-core-differences for details.");
default:
reader.Skip();
if (reader.TokenType == JsonTokenType.StartObject)
{
availableTransports.Add(ParseAvailableTransport(ref reader));
}
else if (reader.TokenType == JsonTokenType.EndArray)
{
break;
}
}
break;
case JsonToken.EndObject:
completed = true;
break;
default:
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading negotiation response JSON.");
}
}
else if (memberName.SequenceEqual(ErrorPropertyNameBytes))
{
error = reader.ReadAsString(ErrorPropertyNameBytes);
}
else if (memberName.SequenceEqual(ProtocolVersionPropertyNameBytes))
{
throw new InvalidOperationException("Detected a connection attempt to an ASP.NET SignalR Server. This client only supports connecting to an ASP.NET Core SignalR Server. See https://aka.ms/signalr-core-differences for details.");
}
else
{
reader.Skip();
}
break;
case JsonTokenType.EndObject:
completed = true;
break;
default:
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading negotiation response JSON.");
}
if (url == null && error == null)
{
// if url isn't specified or there isn't an error, connectionId and available transports are required
if (connectionId == null)
{
throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'.");
}
if (availableTransports == null)
{
throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'.");
}
}
return new NegotiationResponse
{
ConnectionId = connectionId,
Url = url,
AccessToken = accessToken,
AvailableTransports = availableTransports,
Error = error,
};
}
if (url == null && error == null)
{
// if url isn't specified or there isn't an error, connectionId and available transports are required
if (connectionId == null)
{
throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'.");
}
if (availableTransports == null)
{
throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'.");
}
}
return new NegotiationResponse
{
ConnectionId = connectionId,
Url = url,
AccessToken = accessToken,
AvailableTransports = availableTransports,
Error = error,
};
}
catch (Exception ex)
{
@ -186,49 +192,60 @@ namespace Microsoft.AspNetCore.Http.Connections
}
}
private static AvailableTransport ParseAvailableTransport(JsonTextReader reader)
/// <summary>
/// <para>
/// This method is obsolete and will be removed in a future version.
/// The recommended alternative is <see cref="ParseResponse(ReadOnlySpan{byte})" />.
/// </para>
/// </summary>
[Obsolete("This method is obsolete and will be removed in a future version. The recommended alternative is ParseResponse(ReadOnlySpan{byte}).")]
public static NegotiationResponse ParseResponse(Stream content) =>
throw new NotSupportedException("This method is obsolete and will be removed in a future version. The recommended alternative is ParseResponse(ReadOnlySpan{byte}).");
private static AvailableTransport ParseAvailableTransport(ref Utf8JsonReader reader)
{
var availableTransport = new AvailableTransport();
while (JsonUtils.CheckRead(reader))
while (reader.CheckRead())
{
switch (reader.TokenType)
{
case JsonToken.PropertyName:
var memberName = reader.Value.ToString();
case JsonTokenType.PropertyName:
var memberName = reader.ValueSpan;
switch (memberName)
if (memberName.SequenceEqual(TransportPropertyNameBytes))
{
case TransportPropertyName:
availableTransport.Transport = JsonUtils.ReadAsString(reader, TransportPropertyName);
break;
case TransferFormatsPropertyName:
JsonUtils.CheckRead(reader);
JsonUtils.EnsureArrayStart(reader);
availableTransport.Transport = reader.ReadAsString(TransportPropertyNameBytes);
}
else if (memberName.SequenceEqual(TransferFormatsPropertyNameBytes))
{
reader.CheckRead();
reader.EnsureArrayStart();
var completed = false;
availableTransport.TransferFormats = new List<string>();
while (!completed && JsonUtils.CheckRead(reader))
var completed = false;
availableTransport.TransferFormats = new List<string>();
while (!completed && reader.CheckRead())
{
switch (reader.TokenType)
{
switch (reader.TokenType)
{
case JsonToken.String:
availableTransport.TransferFormats.Add(reader.Value.ToString());
break;
case JsonToken.EndArray:
completed = true;
break;
default:
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading transfer formats JSON.");
}
case JsonTokenType.String:
availableTransport.TransferFormats.Add(reader.GetString());
break;
case JsonTokenType.EndArray:
completed = true;
break;
default:
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading transfer formats JSON.");
}
break;
default:
reader.Skip();
break;
}
}
else
{
reader.Skip();
}
break;
case JsonToken.EndObject:
case JsonTokenType.EndObject:
if (availableTransport.Transport == null)
{
throw new InvalidDataException($"Missing required property '{TransportPropertyName}'.");

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
@ -13,4 +13,4 @@ namespace Microsoft.AspNetCore.Http.Connections
public IList<AvailableTransport> AvailableTransports { get; set; }
public string Error { get; set; }
}
}
}

View File

@ -21,8 +21,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports, string url, string accessToken)
{
var responseData = Encoding.UTF8.GetBytes(json);
var ms = new MemoryStream(responseData);
var response = NegotiateProtocol.ParseResponse(ms);
var response = NegotiateProtocol.ParseResponse(responseData);
Assert.Equal(connectionId, response.ConnectionId);
Assert.Equal(availableTransports?.Length, response.AvailableTransports?.Count);
@ -48,9 +47,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
public void ParsingNegotiateResponseMessageThrowsForInvalid(string payload, string expectedMessage)
{
var responseData = Encoding.UTF8.GetBytes(payload);
var ms = new MemoryStream(responseData);
var exception = Assert.Throws<InvalidDataException>(() => NegotiateProtocol.ParseResponse(ms));
var exception = Assert.Throws<InvalidDataException>(() => NegotiateProtocol.ParseResponse(responseData));
Assert.Equal(expectedMessage, exception.InnerException.Message);
}
@ -69,9 +67,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
"\"LongPollDelay\":0.0}";
var responseData = Encoding.UTF8.GetBytes(payload);
var ms = new MemoryStream(responseData);
var exception = Assert.Throws<InvalidDataException>(() => NegotiateProtocol.ParseResponse(ms));
var exception = Assert.Throws<InvalidDataException>(() => NegotiateProtocol.ParseResponse(responseData));
Assert.Equal("Detected a connection attempt to an ASP.NET SignalR Server. This client only supports connecting to an ASP.NET Core SignalR Server. See https://aka.ms/signalr-core-differences for details.", exception.InnerException.Message);
}

View File

@ -0,0 +1,102 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using System.Text;
using System.Text.Json;
namespace Microsoft.AspNetCore.Internal
{
internal static class SystemTextJsonExtensions
{
public static bool CheckRead(this ref Utf8JsonReader reader)
{
if (!reader.Read())
{
throw new InvalidDataException("Unexpected end when reading JSON.");
}
return true;
}
public static void EnsureObjectStart(this ref Utf8JsonReader reader)
{
if (reader.TokenType != JsonTokenType.StartObject)
{
throw new InvalidDataException($"Unexpected JSON Token Type '{GetTokenString(reader.TokenType)}'. Expected a JSON Object.");
}
}
public static string GetTokenString(JsonTokenType tokenType)
{
switch (tokenType)
{
case JsonTokenType.None:
break;
case JsonTokenType.StartObject:
return "Object";
case JsonTokenType.StartArray:
return "Array";
case JsonTokenType.PropertyName:
return "Property";
default:
break;
}
return tokenType.ToString();
}
public static void EnsureArrayStart(this ref Utf8JsonReader reader)
{
if (reader.TokenType != JsonTokenType.StartArray)
{
throw new InvalidDataException($"Unexpected JSON Token Type '{GetTokenString(reader.TokenType)}'. Expected a JSON Array.");
}
}
// Remove after https://github.com/dotnet/corefx/issues/33295 is done
public static void Skip(this ref Utf8JsonReader reader)
{
if (reader.TokenType == JsonTokenType.PropertyName)
{
reader.Read();
}
if (reader.TokenType == JsonTokenType.StartObject || reader.TokenType == JsonTokenType.StartArray)
{
int depth = reader.CurrentDepth;
while (reader.Read() && depth < reader.CurrentDepth)
{
}
}
}
public static string ReadAsString(this ref Utf8JsonReader reader, byte[] propertyName)
{
reader.Read();
if (reader.TokenType != JsonTokenType.String)
{
throw new InvalidDataException($"Expected '{Encoding.UTF8.GetString(propertyName)}' to be of type {JsonTokenType.String}.");
}
return reader.GetString();
}
public static int? ReadAsInt32(this ref Utf8JsonReader reader, byte[] propertyName)
{
reader.Read();
if (reader.TokenType == JsonTokenType.Null)
{
return null;
}
if (reader.TokenType != JsonTokenType.Number)
{
throw new InvalidDataException($"Expected '{Encoding.UTF8.GetString(propertyName)}' to be of type {JsonTokenType.Number}.");
}
return reader.GetInt32();
}
}
}

View File

@ -3,12 +3,38 @@
using System;
using System.Buffers;
using System.Runtime.CompilerServices;
namespace Microsoft.AspNetCore.Internal
{
internal static class TextMessageParser
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool TryParseMessage(ref ReadOnlySequence<byte> buffer, out ReadOnlySequence<byte> payload)
{
if (buffer.IsSingleSegment)
{
var span = buffer.First.Span;
var index = span.IndexOf(TextMessageFormatter.RecordSeparator);
if (index == -1)
{
payload = default;
return false;
}
payload = buffer.Slice(0, index);
buffer = buffer.Slice(index + 1);
return true;
}
else
{
return TryParseMessageMultiSegment(ref buffer, out payload);
}
}
private static bool TryParseMessageMultiSegment(ref ReadOnlySequence<byte> buffer, out ReadOnlySequence<byte> payload)
{
var position = buffer.PositionOf(TextMessageFormatter.RecordSeparator);
if (position == null)

View File

@ -6,10 +6,11 @@
<RootNamespace>Microsoft.AspNetCore.SignalR</RootNamespace>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<IsShippingPackage>true</IsShippingPackage>
<NoWarn>$(NoWarn);3021</NoWarn>
</PropertyGroup>
<ItemGroup>
<Compile Include="$(SignalRSharedSourceRoot)JsonUtils.cs" Link="Internal\JsonUtils.cs" />
<Compile Include="$(SignalRSharedSourceRoot)SystemTextJsonExtensions.cs" Link="Internal\SystemTextJsonExtensions.cs" />
<Compile Include="$(SignalRSharedSourceRoot)MemoryBufferWriter.cs" Link="Internal\MemoryBufferWriter.cs" />
<Compile Include="$(SignalRSharedSourceRoot)TextMessageFormatter.cs" Link="Internal\TextMessageFormatter.cs" />
<Compile Include="$(SignalRSharedSourceRoot)TextMessageParser.cs" Link="Internal\TextMessageParser.cs" />
@ -20,8 +21,11 @@
<ItemGroup>
<Reference Include="Microsoft.AspNetCore.Connections.Abstractions" />
<Reference Include="Microsoft.Extensions.Options" />
<Reference Include="Newtonsoft.Json" />
<Reference Include="System.Buffers" />
<Reference Include="Microsoft.Bcl.Json.Sources" Condition="'$(TargetFramework)' == 'netstandard2.0'">
<PrivateAssets>All</PrivateAssets>
</Reference>
<Reference Include="System.Runtime.CompilerServices.Unsafe" Condition="'$(TargetFramework)' == 'netstandard2.0'" />
</ItemGroup>
</Project>

View File

@ -6,9 +6,9 @@ using System.Buffers;
using System.Collections.Concurrent;
using System.IO;
using System.Text;
using System.Text.Json;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.SignalR.Internal;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Protocol
{
@ -18,17 +18,22 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
public static class HandshakeProtocol
{
private const string ProtocolPropertyName = "protocol";
private static readonly byte[] ProtocolPropertyNameBytes = Encoding.UTF8.GetBytes(ProtocolPropertyName);
private const string ProtocolVersionPropertyName = "version";
private static readonly byte[] ProtocolVersionPropertyNameBytes = Encoding.UTF8.GetBytes(ProtocolVersionPropertyName);
private const string MinorVersionPropertyName = "minorVersion";
private static readonly byte[] MinorVersionPropertyNameBytes = Encoding.UTF8.GetBytes(MinorVersionPropertyName);
private const string ErrorPropertyName = "error";
private static readonly byte[] ErrorPropertyNameBytes = Encoding.UTF8.GetBytes(ErrorPropertyName);
private const string TypePropertyName = "type";
private static readonly byte[] TypePropertyNameBytes = Encoding.UTF8.GetBytes(TypePropertyName);
private static ConcurrentDictionary<IHubProtocol, ReadOnlyMemory<byte>> _messageCache = new ConcurrentDictionary<IHubProtocol, ReadOnlyMemory<byte>>();
public static ReadOnlySpan<byte> GetSuccessfulHandshake(IHubProtocol protocol)
{
ReadOnlyMemory<byte> result;
if(!_messageCache.TryGetValue(protocol, out result))
if (!_messageCache.TryGetValue(protocol, out result))
{
var memoryBufferWriter = MemoryBufferWriter.Get();
try
@ -53,24 +58,13 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
/// <param name="output">The output writer.</param>
public static void WriteRequestMessage(HandshakeRequestMessage requestMessage, IBufferWriter<byte> output)
{
var textWriter = Utf8BufferTextWriter.Get(output);
try
{
using (var writer = JsonUtils.CreateJsonTextWriter(textWriter))
{
writer.WriteStartObject();
writer.WritePropertyName(ProtocolPropertyName);
writer.WriteValue(requestMessage.Protocol);
writer.WritePropertyName(ProtocolVersionPropertyName);
writer.WriteValue(requestMessage.Version);
writer.WriteEndObject();
writer.Flush();
}
}
finally
{
Utf8BufferTextWriter.Return(textWriter);
}
var writer = new Utf8JsonWriter(output, new JsonWriterState(new JsonWriterOptions() { SkipValidation = true }));
writer.WriteStartObject();
writer.WriteString(ProtocolPropertyNameBytes, requestMessage.Protocol, escape: false);
writer.WriteNumber(ProtocolVersionPropertyNameBytes, requestMessage.Version, escape: false);
writer.WriteEndObject();
writer.Flush(isFinalBlock: true);
TextMessageFormatter.WriteRecordSeparator(output);
}
@ -82,30 +76,19 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
/// <param name="output">The output writer.</param>
public static void WriteResponseMessage(HandshakeResponseMessage responseMessage, IBufferWriter<byte> output)
{
var textWriter = Utf8BufferTextWriter.Get(output);
try
{
using (var writer = JsonUtils.CreateJsonTextWriter(textWriter))
{
writer.WriteStartObject();
if (!string.IsNullOrEmpty(responseMessage.Error))
{
writer.WritePropertyName(ErrorPropertyName);
writer.WriteValue(responseMessage.Error);
}
var writer = new Utf8JsonWriter(output, new JsonWriterState(new JsonWriterOptions() { SkipValidation = true }));
writer.WritePropertyName(MinorVersionPropertyName);
writer.WriteValue(responseMessage.MinorVersion);
writer.WriteEndObject();
writer.Flush();
}
}
finally
writer.WriteStartObject();
if (!string.IsNullOrEmpty(responseMessage.Error))
{
Utf8BufferTextWriter.Return(textWriter);
writer.WriteString(ErrorPropertyNameBytes, responseMessage.Error);
}
writer.WriteNumber(MinorVersionPropertyNameBytes, responseMessage.MinorVersion, escape: false);
writer.WriteEndObject();
writer.Flush(isFinalBlock: true);
TextMessageFormatter.WriteRecordSeparator(output);
}
@ -123,59 +106,51 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
return false;
}
var textReader = Utf8BufferTextReader.Get(payload);
var reader = new Utf8JsonReader(in payload, isFinalBlock: true, state: default);
try
reader.CheckRead();
reader.EnsureObjectStart();
int? minorVersion = null;
string error = null;
while (reader.CheckRead())
{
using (var reader = JsonUtils.CreateJsonTextReader(textReader))
if (reader.TokenType == JsonTokenType.PropertyName)
{
JsonUtils.CheckRead(reader);
JsonUtils.EnsureObjectStart(reader);
var memberName = reader.HasValueSequence ? reader.ValueSequence.ToArray() : reader.ValueSpan;
int? minorVersion = null;
string error = null;
var completed = false;
while (!completed && JsonUtils.CheckRead(reader))
if (memberName.SequenceEqual(TypePropertyNameBytes))
{
switch (reader.TokenType)
{
case JsonToken.PropertyName:
var memberName = reader.Value.ToString();
switch (memberName)
{
case TypePropertyName:
// a handshake response does not have a type
// check the incoming message was not any other type of message
throw new InvalidDataException("Expected a handshake response from the server.");
case ErrorPropertyName:
error = JsonUtils.ReadAsString(reader, ErrorPropertyName);
break;
case MinorVersionPropertyName:
minorVersion = JsonUtils.ReadAsInt32(reader, MinorVersionPropertyName);
break;
default:
reader.Skip();
break;
}
break;
case JsonToken.EndObject:
completed = true;
break;
default:
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading handshake response JSON.");
}
};
responseMessage = new HandshakeResponseMessage(minorVersion, error);
return true;
// a handshake response does not have a type
// check the incoming message was not any other type of message
throw new InvalidDataException("Expected a handshake response from the server.");
}
else if (memberName.SequenceEqual(ErrorPropertyNameBytes))
{
error = reader.ReadAsString(ErrorPropertyNameBytes);
}
else if (memberName.SequenceEqual(MinorVersionPropertyNameBytes))
{
minorVersion = reader.ReadAsInt32(MinorVersionPropertyNameBytes);
}
else
{
reader.Skip();
}
}
}
finally
{
Utf8BufferTextReader.Return(textReader);
}
else if (reader.TokenType == JsonTokenType.EndObject)
{
break;
}
else
{
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading handshake response JSON.");
}
};
responseMessage = new HandshakeResponseMessage(minorVersion, error);
return true;
}
/// <summary>
@ -192,62 +167,53 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
return false;
}
var textReader = Utf8BufferTextReader.Get(payload);
try
var reader = new Utf8JsonReader(in payload, isFinalBlock: true, state: default);
reader.CheckRead();
reader.EnsureObjectStart();
string protocol = null;
int? protocolVersion = null;
while (reader.CheckRead())
{
using (var reader = JsonUtils.CreateJsonTextReader(textReader))
if (reader.TokenType == JsonTokenType.PropertyName)
{
JsonUtils.CheckRead(reader);
JsonUtils.EnsureObjectStart(reader);
var memberName = reader.HasValueSequence ? reader.ValueSequence.ToArray() : reader.ValueSpan;
string protocol = null;
int? protocolVersion = null;
var completed = false;
while (!completed && JsonUtils.CheckRead(reader))
if (memberName.SequenceEqual(ProtocolPropertyNameBytes))
{
switch (reader.TokenType)
{
case JsonToken.PropertyName:
var memberName = reader.Value.ToString();
switch (memberName)
{
case ProtocolPropertyName:
protocol = JsonUtils.ReadAsString(reader, ProtocolPropertyName);
break;
case ProtocolVersionPropertyName:
protocolVersion = JsonUtils.ReadAsInt32(reader, ProtocolVersionPropertyName);
break;
default:
reader.Skip();
break;
}
break;
case JsonToken.EndObject:
completed = true;
break;
default:
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading handshake request JSON. Message content: {GetPayloadAsString()}");
}
protocol = reader.ReadAsString(ProtocolPropertyNameBytes);
}
if (protocol == null)
else if (memberName.SequenceEqual(ProtocolVersionPropertyNameBytes))
{
throw new InvalidDataException($"Missing required property '{ProtocolPropertyName}'. Message content: {GetPayloadAsString()}");
protocolVersion = reader.ReadAsInt32(ProtocolVersionPropertyNameBytes);
}
if (protocolVersion == null)
else
{
throw new InvalidDataException($"Missing required property '{ProtocolVersionPropertyName}'. Message content: {GetPayloadAsString()}");
reader.Skip();
}
requestMessage = new HandshakeRequestMessage(protocol, protocolVersion.Value);
}
else if (reader.TokenType == JsonTokenType.EndObject)
{
break;
}
else
{
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading handshake request JSON. Message content: {GetPayloadAsString()}");
}
}
finally
if (protocol == null)
{
Utf8BufferTextReader.Return(textReader);
throw new InvalidDataException($"Missing required property '{ProtocolPropertyName}'. Message content: {GetPayloadAsString()}");
}
if (protocolVersion == null)
{
throw new InvalidDataException($"Missing required property '{ProtocolVersionPropertyName}'. Message content: {GetPayloadAsString()}");
}
requestMessage = new HandshakeRequestMessage(protocol, protocolVersion.Value);
// For error messages, we want to print the payload as text
string GetPayloadAsString()

View File

@ -25,6 +25,17 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
Assert.Equal(version, deserializedMessage.Version);
}
[Fact]
public void ParsingHandshakeRequestMessageSuccessForValidMessageWithMultipleSegments()
{
var message = ReadOnlySequenceFactory.SegmentPerByteFactory.CreateWithContent("{\"protocol\":\"json\",\"version\":1}\u001e");
Assert.True(HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage));
Assert.Equal("json", deserializedMessage.Protocol);
Assert.Equal(1, deserializedMessage.Version);
}
[Theory]
[InlineData("{\"error\":\"dummy\"}\u001e", "dummy")]
[InlineData("{\"error\":\"\"}\u001e", "")]
@ -38,6 +49,16 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
Assert.Equal(error, response.Error);
}
[Fact]
public void ParsingHandshakeResponseMessageSuccessForValidMessageWithMultipleSegments()
{
var message = ReadOnlySequenceFactory.SegmentPerByteFactory.CreateWithContent("{\"error\":\"dummy\"}\u001e");
Assert.True(HandshakeProtocol.TryParseResponseMessage(ref message, out var response));
Assert.Equal("dummy", response.Error);
}
[Theory]
[InlineData("{\"error\":\"\",\"minorVersion\":34}\u001e", 34)]
[InlineData("{\"error\":\"flump flump flump\",\"minorVersion\":112}\u001e", 112)]
@ -58,7 +79,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
}
[Theory]
[InlineData("42\u001e", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")]
[InlineData("42\u001e", "Unexpected JSON Token Type 'Number'. Expected a JSON Object.")]
[InlineData("\"42\"\u001e", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")]
[InlineData("null\u001e", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")]
[InlineData("{}\u001e", "Missing required property 'protocol'. Message content: {}")]
@ -66,7 +87,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[InlineData("{\"protocol\":\"json\"}\u001e", "Missing required property 'version'. Message content: {\"protocol\":\"json\"}")]
[InlineData("{\"version\":1}\u001e", "Missing required property 'protocol'. Message content: {\"version\":1}")]
[InlineData("{\"type\":4,\"invocationId\":\"42\",\"target\":\"foo\",\"arguments\":{}}\u001e", "Missing required property 'protocol'. Message content: {\"type\":4,\"invocationId\":\"42\",\"target\":\"foo\",\"arguments\":{}}")]
[InlineData("{\"version\":\"123\"}\u001e", "Expected 'version' to be of type Integer.")]
[InlineData("{\"version\":\"123\"}\u001e", "Expected 'version' to be of type Number.")]
[InlineData("{\"protocol\":null,\"version\":123}\u001e", "Expected 'protocol' to be of type String.")]
public void ParsingHandshakeRequestMessageThrowsForInvalidMessages(string payload, string expectedMessage)
{
@ -79,7 +100,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
}
[Theory]
[InlineData("42\u001e", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")]
[InlineData("42\u001e", "Unexpected JSON Token Type 'Number'. Expected a JSON Object.")]
[InlineData("\"42\"\u001e", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")]
[InlineData("null\u001e", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")]
[InlineData("[]\u001e", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")]

View File

@ -7,6 +7,7 @@
<ItemGroup>
<Compile Include="$(SignalRSharedSourceRoot)BinaryMessageFormatter.cs" Link="BinaryMessageFormatter.cs" />
<Compile Include="$(SignalRSharedSourceRoot)BinaryMessageParser.cs" Link="BinaryMessageParser.cs" />
<Compile Include="$(SharedSourceRoot)Buffers.Testing\**\*.cs" />
</ItemGroup>
<ItemGroup>

View File

@ -16,6 +16,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
ReadOnlySequence<byte> _requestMessage2;
ReadOnlySequence<byte> _requestMessage3;
ReadOnlySequence<byte> _requestMessage4;
ReadOnlySequence<byte> _requestMessage5;
ReadOnlySequence<byte> _responseMessage1;
ReadOnlySequence<byte> _responseMessage2;
@ -31,6 +32,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
_requestMessage2 = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("{\"protocol\":\"\",\"version\":10}\u001e"));
_requestMessage3 = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("{\"protocol\":\"\",\"version\":10,\"unknown\":null}\u001e"));
_requestMessage4 = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("42"));
_requestMessage5 = ReadOnlySequenceFactory.CreateSegments(Encoding.UTF8.GetBytes("{\"protocol\":\"dummy\",\"ver"), Encoding.UTF8.GetBytes("sion\":1}\u001e"));
_responseMessage1 = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("{\"error\":\"dummy\"}\u001e"));
_responseMessage2 = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("{\"error\":\"\"}\u001e"));
@ -85,43 +87,112 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
}
[Benchmark]
public void ParsingHandshakeRequestMessage_ValidMessage1()
=> HandshakeProtocol.TryParseRequestMessage(ref _requestMessage1, out var deserializedMessage);
public void ParsingHandshakeRequestMessage_ValidMessage1() {
var message = _requestMessage1;
if (!HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage))
{
throw new Exception();
}
}
[Benchmark]
public void ParsingHandshakeRequestMessage_ValidMessage2()
=> HandshakeProtocol.TryParseRequestMessage(ref _requestMessage2, out var deserializedMessage);
{
var message = _requestMessage2;
if (!HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage))
{
throw new Exception();
}
}
[Benchmark]
public void ParsingHandshakeRequestMessage_ValidMessage3()
=> HandshakeProtocol.TryParseRequestMessage(ref _requestMessage3, out var deserializedMessage);
{
var message = _requestMessage3;
if (!HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage))
{
throw new Exception();
}
}
[Benchmark]
public void ParsingHandshakeRequestMessage_NotComplete1()
=> HandshakeProtocol.TryParseRequestMessage(ref _requestMessage4, out _);
{
var message = _requestMessage4;
if (!HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage))
{
throw new Exception();
}
}
[Benchmark]
public void ParsingHandshakeRequestMessage_ValidMessageSegments()
{
var message = _requestMessage5;
if (!HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage))
{
throw new Exception();
}
}
[Benchmark]
public void ParsingHandshakeResponseMessage_ValidMessages1()
=> HandshakeProtocol.TryParseResponseMessage(ref _responseMessage1, out var response);
{
var message = _responseMessage1;
if (!HandshakeProtocol.TryParseResponseMessage(ref message, out var deserializedMessage))
{
throw new Exception();
}
}
[Benchmark]
public void ParsingHandshakeResponseMessage_ValidMessages2()
=> HandshakeProtocol.TryParseResponseMessage(ref _responseMessage2, out var response);
{
var message = _responseMessage2;
if (!HandshakeProtocol.TryParseResponseMessage(ref message, out var deserializedMessage))
{
throw new Exception();
}
}
[Benchmark]
public void ParsingHandshakeResponseMessage_ValidMessages3()
=> HandshakeProtocol.TryParseResponseMessage(ref _responseMessage3, out var response);
{
var message = _responseMessage3;
if (!HandshakeProtocol.TryParseResponseMessage(ref message, out var deserializedMessage))
{
throw new Exception();
}
}
[Benchmark]
public void ParsingHandshakeResponseMessage_ValidMessages4()
=> HandshakeProtocol.TryParseResponseMessage(ref _responseMessage4, out var response);
{
var message = _responseMessage4;
if (!HandshakeProtocol.TryParseResponseMessage(ref message, out var deserializedMessage))
{
throw new Exception();
}
}
[Benchmark]
public void ParsingHandshakeResponseMessage_GivesMinorVersion1()
=> HandshakeProtocol.TryParseResponseMessage(ref _responseMessage5, out var response);
{
var message = _responseMessage5;
if (!HandshakeProtocol.TryParseResponseMessage(ref message, out var deserializedMessage))
{
throw new Exception();
}
}
[Benchmark]
public void ParsingHandshakeResponseMessage_GivesMinorVersion2()
=> HandshakeProtocol.TryParseResponseMessage(ref _responseMessage6, out var response);
{
var message = _responseMessage6;
if (!HandshakeProtocol.TryParseResponseMessage(ref message, out var deserializedMessage))
{
throw new Exception();
}
}
}
}

View File

@ -8,6 +8,7 @@
<ItemGroup>
<Compile Include="$(SignalRSharedSourceRoot)BinaryMessageFormatter.cs" Link="BinaryMessageFormatter.cs" />
<Compile Include="$(SignalRSharedSourceRoot)BinaryMessageParser.cs" Link="BinaryMessageParser.cs" />
<Compile Include="$(SharedSourceRoot)Buffers.Testing\**\*.cs" />
</ItemGroup>
<ItemGroup>

View File

@ -70,22 +70,22 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[Benchmark]
public void ParsingNegotiateResponseMessageSuccessForValid1()
=> NegotiateProtocol.ParseResponse(new MemoryStream(_responseData1));
=> NegotiateProtocol.ParseResponse(_responseData1);
[Benchmark]
public void ParsingNegotiateResponseMessageSuccessForValid2()
=> NegotiateProtocol.ParseResponse(new MemoryStream(_responseData2));
=> NegotiateProtocol.ParseResponse(_responseData2);
[Benchmark]
public void ParsingNegotiateResponseMessageSuccessForValid3()
=> NegotiateProtocol.ParseResponse(new MemoryStream(_responseData3));
=> NegotiateProtocol.ParseResponse(_responseData3);
[Benchmark]
public void ParsingNegotiateResponseMessageSuccessForValid4()
=> NegotiateProtocol.ParseResponse(new MemoryStream(_responseData4));
=> NegotiateProtocol.ParseResponse(_responseData4);
[Benchmark]
public void ParsingNegotiateResponseMessageSuccessForValid5()
=> NegotiateProtocol.ParseResponse(new MemoryStream(_responseData5));
=> NegotiateProtocol.ParseResponse(_responseData5);
}
}