Merge pull request #1760 from aspnet/release/2.1

Handshake and negotiation optimization (#1731)
This commit is contained in:
James Newton-King 2018-03-29 17:51:40 +13:00 committed by GitHub
commit a4dd4da7a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 498 additions and 175 deletions

View File

@ -31,14 +31,14 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
private TestUserIdProvider _userIdProvider;
private List<string> _supportedProtocols;
private TestDuplexPipe _pipe;
private ReadResult _handshakeResponseResult;
private ReadResult _handshakeRequestResult;
[GlobalSetup]
public void GlobalSetup()
{
var memoryBufferWriter = new MemoryBufferWriter();
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter);
_handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(memoryBufferWriter.ToArray()), false, false);
_handshakeRequestResult = new ReadResult(new ReadOnlySequence<byte>(memoryBufferWriter.ToArray()), false, false);
_pipe = new TestDuplexPipe();
@ -54,7 +54,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[Benchmark]
public async Task SuccessHandshakeAsync()
{
_pipe.AddReadResult(_handshakeResponseResult);
_pipe.AddReadResult(_handshakeRequestResult);
await _hubConnectionContext.HandshakeAsync(TimeSpan.FromSeconds(5), _supportedProtocols, _successHubProtocolResolver, _userIdProvider);
}
@ -62,7 +62,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[Benchmark]
public async Task ErrorHandshakeAsync()
{
_pipe.AddReadResult(_handshakeResponseResult);
_pipe.AddReadResult(_handshakeRequestResult);
await _hubConnectionContext.HandshakeAsync(TimeSpan.FromSeconds(5), _supportedProtocols, _failureHubProtocolResolver, _userIdProvider);
}

View File

@ -7,7 +7,7 @@ using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public static class ForceAsyncTaskExtensions
internal static class ForceAsyncTaskExtensions
{
/// <summary>
/// Returns an awaitable/awaiter that will ensure the continuation is executed
@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
}
}
public struct ForceAsyncAwaiter : ICriticalNotifyCompletion
internal struct ForceAsyncAwaiter : ICriticalNotifyCompletion
{
private readonly Task _task;
@ -50,7 +50,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
}
}
public struct ForceAsyncAwaiter<T> : ICriticalNotifyCompletion
internal struct ForceAsyncAwaiter<T> : ICriticalNotifyCompletion
{
private readonly Task<T> _task;

View File

@ -2,25 +2,40 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.IO;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
namespace Microsoft.AspNetCore.SignalR.Internal
{
public static class JsonUtils
internal static class JsonUtils
{
internal static JsonTextReader CreateJsonTextReader(Utf8BufferTextReader textReader)
internal static JsonTextReader CreateJsonTextReader(TextReader textReader)
{
var reader = new JsonTextReader(textReader);
reader.ArrayPool = JsonArrayPool<char>.Shared;
// Don't close the output, Utf8BufferTextReader is resettable
// Don't close the input, leave closing to the caller
reader.CloseInput = false;
return reader;
}
internal static JsonTextWriter CreateJsonTextWriter(TextWriter textWriter)
{
var writer = new JsonTextWriter(textWriter);
// Don't close the output, leave closing to the caller
writer.CloseOutput = false;
// SignalR will always write a complete JSON response
// This setting will prevent an error during writing be hidden by another error writing on dispose
writer.AutoCompleteOnClose = false;
return writer;
}
public static JObject GetObject(JToken token)
{
if (token == null || token.Type != JTokenType.Object)
@ -82,6 +97,22 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return tokenType.ToString();
}
public static void EnsureObjectStart(JsonTextReader reader)
{
if (reader.TokenType != JsonToken.StartObject)
{
throw new InvalidDataException($"Unexpected JSON Token Type '{GetTokenString(reader.TokenType)}'. Expected a JSON Object.");
}
}
public static void EnsureArrayStart(JsonTextReader reader)
{
if (reader.TokenType != JsonToken.StartArray)
{
throw new InvalidDataException($"Unexpected JSON Token Type '{GetTokenString(reader.TokenType)}'. Expected a JSON Array.");
}
}
public static int? ReadAsInt32(JsonTextReader reader, string propertyName)
{
reader.Read();
@ -115,10 +146,32 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
if (!reader.Read())
{
throw new JsonReaderException("Unexpected end when reading JSON");
throw new InvalidDataException("Unexpected end when reading JSON.");
}
return true;
}
private class JsonArrayPool<T> : IArrayPool<T>
{
private readonly ArrayPool<T> _inner;
internal static readonly JsonArrayPool<T> Shared = new JsonArrayPool<T>(ArrayPool<T>.Shared);
public JsonArrayPool(ArrayPool<T> inner)
{
_inner = inner;
}
public T[] Rent(int minimumLength)
{
return _inner.Rent(minimumLength);
}
public void Return(T[] array)
{
_inner.Return(array);
}
}
}
}
}

View File

@ -523,7 +523,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
{
if (HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message))
{
if (!string.IsNullOrEmpty(message.Error))
if (message.Error != null)
{
Log.HandshakeServerError(_logger, message.Error);
throw new HubException(

View File

@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
public static class HandshakeProtocol
{
private const string ProtocolPropertyName = "protocol";
private const string ProtocolVersionName = "version";
private const string ProtocolVersionPropertyName = "version";
private const string ErrorPropertyName = "error";
private const string TypePropertyName = "type";
@ -23,12 +23,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
var textWriter = Utf8BufferTextWriter.Get(output);
try
{
using (var writer = CreateJsonTextWriter(textWriter))
using (var writer = JsonUtils.CreateJsonTextWriter(textWriter))
{
writer.WriteStartObject();
writer.WritePropertyName(ProtocolPropertyName);
writer.WriteValue(requestMessage.Protocol);
writer.WritePropertyName(ProtocolVersionName);
writer.WritePropertyName(ProtocolVersionPropertyName);
writer.WriteValue(requestMessage.Version);
writer.WriteEndObject();
writer.Flush();
@ -47,7 +47,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
var textWriter = Utf8BufferTextWriter.Get(output);
try
{
using (var writer = CreateJsonTextWriter(textWriter))
using (var writer = JsonUtils.CreateJsonTextWriter(textWriter))
{
writer.WriteStartObject();
if (!string.IsNullOrEmpty(responseMessage.Error))
@ -68,14 +68,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
TextMessageFormatter.WriteRecordSeparator(output);
}
private static JsonTextWriter CreateJsonTextWriter(TextWriter textWriter)
{
var writer = new JsonTextWriter(textWriter);
writer.CloseOutput = false;
return writer;
}
public static bool TryParseResponseMessage(ref ReadOnlySequence<byte> buffer, out HandshakeResponseMessage responseMessage)
{
if (!TextMessageParser.TryParseMessage(ref buffer, out var payload))
@ -90,19 +82,42 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
using (var reader = JsonUtils.CreateJsonTextReader(textReader))
{
var token = JToken.ReadFrom(reader);
var handshakeJObject = JsonUtils.GetObject(token);
JsonUtils.CheckRead(reader);
JsonUtils.EnsureObjectStart(reader);
// a handshake response does not have a type
// check the incoming message was not any other type of message
var type = JsonUtils.GetOptionalProperty<string>(handshakeJObject, TypePropertyName);
if (!string.IsNullOrEmpty(type))
string error = null;
var completed = false;
while (!completed && JsonUtils.CheckRead(reader))
{
throw new InvalidOperationException("Handshake response should not have a 'type' value.");
}
switch (reader.TokenType)
{
case JsonToken.PropertyName:
string memberName = reader.Value.ToString();
var error = JsonUtils.GetOptionalProperty<string>(handshakeJObject, ErrorPropertyName);
responseMessage = new HandshakeResponseMessage(error);
switch (memberName)
{
case TypePropertyName:
// a handshake response does not have a type
// check the incoming message was not any other type of message
throw new InvalidDataException("Handshake response should not have a 'type' value.");
case ErrorPropertyName:
error = JsonUtils.ReadAsString(reader, ErrorPropertyName);
break;
default:
reader.Skip();
break;
}
break;
case JsonToken.EndObject:
completed = true;
break;
default:
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading handshake response JSON.");
}
};
responseMessage = (error != null) ? new HandshakeResponseMessage(error) : HandshakeResponseMessage.Empty;
return true;
}
}
@ -125,11 +140,51 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
using (var reader = JsonUtils.CreateJsonTextReader(textReader))
{
var token = JToken.ReadFrom(reader);
var handshakeJObject = JsonUtils.GetObject(token);
var protocol = JsonUtils.GetRequiredProperty<string>(handshakeJObject, ProtocolPropertyName);
var protocolVersion = JsonUtils.GetRequiredProperty<int>(handshakeJObject, ProtocolVersionName, JTokenType.Integer);
requestMessage = new HandshakeRequestMessage(protocol, protocolVersion);
JsonUtils.CheckRead(reader);
JsonUtils.EnsureObjectStart(reader);
string protocol = null;
int? protocolVersion = null;
var completed = false;
while (!completed && JsonUtils.CheckRead(reader))
{
switch (reader.TokenType)
{
case JsonToken.PropertyName:
string memberName = reader.Value.ToString();
switch (memberName)
{
case ProtocolPropertyName:
protocol = JsonUtils.ReadAsString(reader, ProtocolPropertyName);
break;
case ProtocolVersionPropertyName:
protocolVersion = JsonUtils.ReadAsInt32(reader, ProtocolVersionPropertyName);
break;
default:
reader.Skip();
break;
}
break;
case JsonToken.EndObject:
completed = true;
break;
default:
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading handshake request JSON.");
}
}
if (protocol == null)
{
throw new InvalidDataException($"Missing required property '{ProtocolPropertyName}'.");
}
if (protocolVersion == null)
{
throw new InvalidDataException($"Missing required property '{ProtocolVersionPropertyName}'.");
}
requestMessage = new HandshakeRequestMessage(protocol, protocolVersion.Value);
}
}
finally

View File

@ -11,6 +11,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
public HandshakeResponseMessage(string error)
{
// Note that a response with an empty string for error in the JSON is considered an errored response
Error = error;
}
}

View File

@ -1,31 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
internal class JsonArrayPool<T> : IArrayPool<T>
{
private readonly ArrayPool<T> _inner;
internal static readonly JsonArrayPool<T> Shared = new JsonArrayPool<T>(ArrayPool<T>.Shared);
public JsonArrayPool(ArrayPool<T> inner)
{
_inner = inner;
}
public T[] Rent(int minimumLength)
{
return _inner.Rent(minimumLength);
}
public void Return(T[] array)
{
_inner.Return(array);
}
}
}

View File

@ -113,10 +113,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
JsonUtils.CheckRead(reader);
// We're always parsing a JSON object
if (reader.TokenType != JsonToken.StartObject)
{
throw new InvalidDataException($"Unexpected JSON Token Type '{JsonUtils.GetTokenString(reader.TokenType)}'. Expected a JSON Object.");
}
JsonUtils.EnsureObjectStart(reader);
do
{
@ -344,7 +341,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
private void WriteMessageCore(HubMessage message, Stream stream)
{
using (var writer = new JsonTextWriter(new StreamWriter(stream, _utf8NoBom, 1024, leaveOpen: true)))
using (var writer = JsonUtils.CreateJsonTextWriter(new StreamWriter(stream, _utf8NoBom, 1024, leaveOpen: true)))
{
writer.WriteStartObject();
switch (message)
@ -385,6 +382,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}");
}
writer.WriteEndObject();
writer.Flush();
}
}

View File

@ -7,6 +7,10 @@
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>
<ItemGroup>
<Compile Include="..\Common\JsonUtils.cs" Link="Internal\JsonUtils.cs" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
<PackageReference Include="System.Memory" Version="$(SystemMemoryPackageVersion)" />

View File

@ -5,6 +5,10 @@
<TargetFramework>netstandard2.0</TargetFramework>
</PropertyGroup>
<ItemGroup>
<Compile Include="..\Common\JsonUtils.cs" Link="Internal\JsonUtils.cs" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Options" Version="$(MicrosoftExtensionsOptionsPackageVersion)" />
<PackageReference Include="StackExchange.Redis.StrongName" Version="$(StackExchangeRedisStrongNamePackageVersion)" />

View File

@ -9,6 +9,7 @@ using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Redis.Internal;
using Microsoft.Extensions.Logging;
@ -214,10 +215,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
byte[] payload;
using (var stream = new LimitArrayPoolWriteStream())
using (var writer = new JsonTextWriter(new StreamWriter(stream)))
using (var writer = JsonUtils.CreateJsonTextWriter(new StreamWriter(stream)))
{
_serializer.Serialize(writer, message);
await writer.FlushAsync();
writer.Flush();
payload = stream.ToArray();
}

View File

@ -299,7 +299,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Http
using (var response = await httpClient.SendAsync(request))
{
response.EnsureSuccessStatusCode();
var negotiateResponse = await ParseNegotiateResponse(response);
var negotiateResponse = NegotiateProtocol.ParseResponse(await response.Content.ReadAsStreamAsync());
Log.ConnectionEstablished(_logger, negotiateResponse.ConnectionId);
return negotiateResponse;
}
@ -312,29 +312,6 @@ namespace Microsoft.AspNetCore.Sockets.Client.Http
}
}
private static async Task<NegotiationResponse> ParseNegotiateResponse(HttpResponseMessage response)
{
NegotiationResponse negotiationResponse;
using (var reader = new JsonTextReader(new StreamReader(await response.Content.ReadAsStreamAsync())))
{
try
{
negotiationResponse = new JsonSerializer().Deserialize<NegotiationResponse>(reader);
}
catch (Exception ex)
{
throw new FormatException("Invalid negotiation response received.", ex);
}
}
if (negotiationResponse == null)
{
throw new FormatException("Invalid negotiation response received.");
}
return negotiationResponse;
}
private static Uri CreateConnectUrl(Uri url, string connectionId)
{
if (string.IsNullOrWhiteSpace(connectionId))
@ -454,17 +431,5 @@ namespace Microsoft.AspNetCore.Sockets.Client.Http
_logScope.ConnectionId = _connectionId;
return negotiationResponse;
}
private class NegotiationResponse
{
public string ConnectionId { get; set; }
public AvailableTransport[] AvailableTransports { get; set; }
}
private class AvailableTransport
{
public string Transport { get; set; }
public string[] TransferFormats { get; set; }
}
}
}

View File

@ -0,0 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public class AvailableTransport
{
public string Transport { get; set; }
public List<string> TransferFormats { get; set; }
}
}

View File

@ -0,0 +1,196 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Microsoft.AspNetCore.SignalR.Internal;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public static class NegotiateProtocol
{
private static readonly UTF8Encoding _utf8NoBom = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false);
private const string ConnectionIdPropertyName = "connectionId";
private const string AvailableTransportsPropertyName = "availableTransports";
private const string TransportPropertyName = "transport";
private const string TransferFormatsPropertyName = "transferFormats";
public static void WriteResponse(NegotiationResponse response, Stream output)
{
using (var jsonWriter = JsonUtils.CreateJsonTextWriter(new StreamWriter(output, _utf8NoBom, 1024, leaveOpen: true)))
{
jsonWriter.WriteStartObject();
jsonWriter.WritePropertyName(ConnectionIdPropertyName);
jsonWriter.WriteValue(response.ConnectionId);
jsonWriter.WritePropertyName(AvailableTransportsPropertyName);
jsonWriter.WriteStartArray();
foreach (var availableTransport in response.AvailableTransports)
{
jsonWriter.WriteStartObject();
jsonWriter.WritePropertyName(TransportPropertyName);
jsonWriter.WriteValue(availableTransport.Transport);
jsonWriter.WritePropertyName(TransferFormatsPropertyName);
jsonWriter.WriteStartArray();
foreach (var transferFormat in availableTransport.TransferFormats)
{
jsonWriter.WriteValue(transferFormat);
}
jsonWriter.WriteEndArray();
jsonWriter.WriteEndObject();
}
jsonWriter.WriteEndArray();
jsonWriter.WriteEndObject();
jsonWriter.Flush();
}
}
public static NegotiationResponse ParseResponse(Stream content)
{
try
{
using (var reader = JsonUtils.CreateJsonTextReader(new StreamReader(content)))
{
JsonUtils.CheckRead(reader);
JsonUtils.EnsureObjectStart(reader);
string connectionId = null;
List<AvailableTransport> availableTransports = null;
var completed = false;
while (!completed && JsonUtils.CheckRead(reader))
{
switch (reader.TokenType)
{
case JsonToken.PropertyName:
var memberName = reader.Value.ToString();
switch (memberName)
{
case ConnectionIdPropertyName:
connectionId = JsonUtils.ReadAsString(reader, ConnectionIdPropertyName);
break;
case AvailableTransportsPropertyName:
JsonUtils.CheckRead(reader);
JsonUtils.EnsureArrayStart(reader);
availableTransports = new List<AvailableTransport>();
while (JsonUtils.CheckRead(reader))
{
if (reader.TokenType == JsonToken.StartObject)
{
availableTransports.Add(ParseAvailableTransport(reader));
}
else if (reader.TokenType == JsonToken.EndArray)
{
break;
}
}
break;
default:
reader.Skip();
break;
}
break;
case JsonToken.EndObject:
completed = true;
break;
default:
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading negotiation response JSON.");
}
}
if (connectionId == null)
{
throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'.");
}
if (availableTransports == null)
{
throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'.");
}
return new NegotiationResponse
{
ConnectionId = connectionId,
AvailableTransports = availableTransports
};
}
}
catch (Exception ex)
{
throw new InvalidDataException("Invalid negotiation response received.", ex);
}
}
private static AvailableTransport ParseAvailableTransport(JsonTextReader reader)
{
var availableTransport = new AvailableTransport();
while (JsonUtils.CheckRead(reader))
{
switch (reader.TokenType)
{
case JsonToken.PropertyName:
string memberName = reader.Value.ToString();
switch (memberName)
{
case TransportPropertyName:
availableTransport.Transport = JsonUtils.ReadAsString(reader, TransportPropertyName);
break;
case TransferFormatsPropertyName:
JsonUtils.CheckRead(reader);
JsonUtils.EnsureArrayStart(reader);
bool completed = false;
availableTransport.TransferFormats = new List<string>();
while (!completed && JsonUtils.CheckRead(reader))
{
switch (reader.TokenType)
{
case JsonToken.String:
availableTransport.TransferFormats.Add(reader.Value.ToString());
break;
case JsonToken.EndArray:
completed = true;
break;
default:
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading transfer formats JSON.");
}
}
break;
default:
reader.Skip();
break;
}
break;
case JsonToken.EndObject:
if (availableTransport.Transport == null)
{
throw new InvalidDataException($"Missing required property '{TransportPropertyName}'.");
}
if (availableTransport.TransferFormats == null)
{
throw new InvalidDataException($"Missing required property '{TransferFormatsPropertyName}'.");
}
return availableTransport;
default:
throw new InvalidDataException($"Unexpected token '{reader.TokenType}' when reading available transport JSON.");
}
}
throw new InvalidDataException("Unexpected end when reading JSON.");
}
}
}

View File

@ -0,0 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public class NegotiationResponse
{
public string ConnectionId { get; set; }
public List<AvailableTransport> AvailableTransports { get; set; }
}
}

View File

@ -5,7 +5,13 @@
<TargetFramework>netstandard2.0</TargetFramework>
<RootNamespace>Microsoft.AspNetCore.Sockets</RootNamespace>
</PropertyGroup>
<ItemGroup>
<Compile Include="..\Common\JsonUtils.cs" Link="Internal\JsonUtils.cs" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Connections.Abstractions" Version="$(MicrosoftAspNetCoreConnectionsAbstractionsPackageVersion)" />
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
<PackageReference Include="System.Buffers" Version="$(SystemBuffersPackageVersion)" />
</ItemGroup>
</Project>

View File

@ -24,6 +24,27 @@ namespace Microsoft.AspNetCore.Sockets
{
public partial class HttpConnectionDispatcher
{
private static readonly AvailableTransport _webSocketAvailableTransport =
new AvailableTransport
{
Transport = nameof(TransportType.WebSockets),
TransferFormats = new List<string> { nameof(TransferFormat.Text), nameof(TransferFormat.Binary) }
};
private static readonly AvailableTransport _serverSentEventsAvailableTransport =
new AvailableTransport
{
Transport = nameof(TransportType.ServerSentEvents),
TransferFormats = new List<string> { nameof(TransferFormat.Text) }
};
private static readonly AvailableTransport _longPollingAvailableTransport =
new AvailableTransport
{
Transport = nameof(TransportType.LongPolling),
TransferFormats = new List<string> { nameof(TransferFormat.Text), nameof(TransferFormat.Binary) }
};
private readonly HttpConnectionManager _manager;
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
@ -368,7 +389,7 @@ namespace Microsoft.AspNetCore.Sockets
logScope.ConnectionId = connection.ConnectionId;
// Get the bytes for the connection id
var negotiateResponseBuffer = Encoding.UTF8.GetBytes(GetNegotiatePayload(connection.ConnectionId, context, options));
var negotiateResponseBuffer = GetNegotiatePayload(connection.ConnectionId, context, options);
Log.NegotiationRequest(_logger);
@ -377,40 +398,31 @@ namespace Microsoft.AspNetCore.Sockets
return context.Response.Body.WriteAsync(negotiateResponseBuffer, 0, negotiateResponseBuffer.Length);
}
private static string GetNegotiatePayload(string connectionId, HttpContext context, HttpConnectionOptions options)
private static byte[] GetNegotiatePayload(string connectionId, HttpContext context, HttpConnectionOptions options)
{
var sb = new StringBuilder();
using (var jsonWriter = new JsonTextWriter(new StringWriter(sb)))
NegotiationResponse response = new NegotiationResponse();
response.ConnectionId = connectionId;
response.AvailableTransports = new List<AvailableTransport>();
if ((options.Transports & TransportType.WebSockets) != 0 && ServerHasWebSockets(context.Features))
{
jsonWriter.WriteStartObject();
jsonWriter.WritePropertyName("connectionId");
jsonWriter.WriteValue(connectionId);
jsonWriter.WritePropertyName("availableTransports");
jsonWriter.WriteStartArray();
if (ServerHasWebSockets(context.Features))
{
if ((options.Transports & TransportType.WebSockets) != 0)
{
WriteTransport(jsonWriter, nameof(TransportType.WebSockets), TransferFormat.Text | TransferFormat.Binary);
}
}
if ((options.Transports & TransportType.ServerSentEvents) != 0)
{
WriteTransport(jsonWriter, nameof(TransportType.ServerSentEvents), TransferFormat.Text);
}
if ((options.Transports & TransportType.LongPolling) != 0)
{
WriteTransport(jsonWriter, nameof(TransportType.LongPolling), TransferFormat.Text | TransferFormat.Binary);
}
jsonWriter.WriteEndArray();
jsonWriter.WriteEndObject();
response.AvailableTransports.Add(_webSocketAvailableTransport);
}
return sb.ToString();
if ((options.Transports & TransportType.ServerSentEvents) != 0)
{
response.AvailableTransports.Add(_serverSentEventsAvailableTransport);
}
if ((options.Transports & TransportType.LongPolling) != 0)
{
response.AvailableTransports.Add(_longPollingAvailableTransport);
}
MemoryStream ms = new MemoryStream();
NegotiateProtocol.WriteResponse(response, ms);
return ms.ToArray();
}
private static bool ServerHasWebSockets(IFeatureCollection features)
@ -418,27 +430,6 @@ namespace Microsoft.AspNetCore.Sockets
return features.Get<IHttpWebSocketFeature>() != null;
}
private static void WriteTransport(JsonWriter writer, string transportName, TransferFormat supportedTransferFormats)
{
writer.WriteStartObject();
writer.WritePropertyName("transport");
writer.WriteValue(transportName);
writer.WritePropertyName("transferFormats");
writer.WriteStartArray();
if ((supportedTransferFormats & TransferFormat.Binary) != 0)
{
writer.WriteValue(nameof(TransferFormat.Binary));
}
if ((supportedTransferFormats & TransferFormat.Text) != 0)
{
writer.WriteValue(nameof(TransferFormat.Text));
}
writer.WriteEndArray();
writer.WriteEndObject();
}
private static string GetConnectionId(HttpContext context) => context.Request.Query["id"];
private async Task ProcessSend(HttpContext context, HttpConnectionOptions options)

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Net;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Client.Tests;
@ -25,13 +26,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
[InlineData("Not Json")]
public Task StartThrowsFormatExceptionIfNegotiationResponseIsInvalid(string negotiatePayload)
{
return RunInvalidNegotiateResponseTest<FormatException>(negotiatePayload, "Invalid negotiation response received.");
return RunInvalidNegotiateResponseTest<InvalidDataException>(negotiatePayload, "Invalid negotiation response received.");
}
[Fact]
public Task StartThrowsFormatExceptionIfNegotiationResponseHasNoConnectionId()
{
return RunInvalidNegotiateResponseTest<FormatException>(ResponseUtils.CreateNegotiationContent(connectionId: null), "Invalid connection id.");
return RunInvalidNegotiateResponseTest<FormatException>(ResponseUtils.CreateNegotiationContent(connectionId: string.Empty), "Invalid connection id.");
}
[Fact]

View File

@ -15,7 +15,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Theory]
[InlineData("{\"protocol\":\"dummy\",\"version\":1}\u001e", "dummy", 1)]
[InlineData("{\"protocol\":\"\",\"version\":10}\u001e", "", 10)]
[InlineData("{\"protocol\":null,\"version\":123}\u001e", null, 123)]
[InlineData("{\"protocol\":\"\",\"version\":10,\"unknown\":null}\u001e", "", 10)]
public void ParsingHandshakeRequestMessageSuccessForValidMessages(string json, string protocol, int version)
{
var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(json));
@ -29,8 +29,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Theory]
[InlineData("{\"error\":\"dummy\"}\u001e", "dummy")]
[InlineData("{\"error\":\"\"}\u001e", "")]
[InlineData("{\"error\":null}\u001e", null)]
[InlineData("{}\u001e", null)]
[InlineData("{\"unknown\":null}\u001e", null)]
public void ParsingHandshakeResponseMessageSuccessForValidMessages(string json, string error)
{
var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(json));
@ -55,7 +55,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[InlineData("[]\u001e", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")]
[InlineData("{\"protocol\":\"json\"}\u001e", "Missing required property 'version'.")]
[InlineData("{\"version\":1}\u001e", "Missing required property 'protocol'.")]
[InlineData("{\"protocol\":null,\"version\":\"123\"}\u001e", "Expected 'version' to be of type Integer.")]
[InlineData("{\"version\":\"123\"}\u001e", "Expected 'version' to be of type Integer.")]
[InlineData("{\"protocol\":null,\"version\":123}\u001e", "Expected 'protocol' to be of type String.")]
public void ParsingHandshakeRequestMessageThrowsForInvalidMessages(string payload, string expectedMessage)
{
var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(payload));
@ -71,12 +72,13 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[InlineData("\"42\"\u001e", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")]
[InlineData("null\u001e", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")]
[InlineData("[]\u001e", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")]
[InlineData("{\"error\":null}\u001e", "Expected 'error' to be of type String.")]
public void ParsingHandshakeResponseMessageThrowsForInvalidMessages(string payload, string expectedMessage)
{
var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(payload));
var exception = Assert.Throws<InvalidDataException>(() =>
HandshakeProtocol.TryParseRequestMessage(ref message, out _));
HandshakeProtocol.TryParseResponseMessage(ref message, out _));
Assert.Equal(expectedMessage, exception.Message);
}

View File

@ -141,7 +141,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
}
[Theory]
[InlineData("", "Error reading JSON.")]
[InlineData("", "Unexpected end when reading JSON.")]
[InlineData("null", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")]
[InlineData("42", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")]
[InlineData("'foo'", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")]
@ -177,7 +177,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[InlineData("{'type':'foo'}", "Expected 'type' to be of type Integer.")]
[InlineData("{'type':3,'invocationId':'42','error':'foo','result':true}", "The 'error' and 'result' properties are mutually exclusive.")]
[InlineData("{'type':3,'invocationId':'42','result':true", "Error reading JSON.")]
[InlineData("{'type':3,'invocationId':'42','result':true", "Unexpected end when reading JSON.")]
public void InvalidMessages(string input, string expectedMessage)
{
input = Frame(input);

View File

@ -0,0 +1,51 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets.Internal;
using Xunit;
namespace Microsoft.AspNetCore.Sockets.Tests
{
public class NegotiateProtocolTests
{
[Theory]
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0])]
[InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0])]
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new [] { "test"})]
public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports)
{
var responseData = Encoding.UTF8.GetBytes(json);
var ms = new MemoryStream(responseData);
var response = NegotiateProtocol.ParseResponse(ms);
Assert.Equal(connectionId, response.ConnectionId);
Assert.Equal(availableTransports.Length, response.AvailableTransports.Count);
var responseTransports = response.AvailableTransports.Select(t => t.Transport).ToList();
Assert.Equal(availableTransports, responseTransports);
}
[Theory]
[InlineData("null", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")]
[InlineData("[]", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")]
[InlineData("{\"availableTransports\":[]}", "Missing required property 'connectionId'.")]
[InlineData("{\"connectionId\":123,\"availableTransports\":[]}", "Expected 'connectionId' to be of type String.")]
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":null}", "Unexpected JSON Token Type 'Null'. Expected a JSON Array.")]
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transferFormats\":[]}]}", "Missing required property 'transport'.")]
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\"}]}", "Missing required property 'transferFormats'.")]
public void ParsingNegotiateResponseMessageThrowsForInvalid(string payload, string expectedMessage)
{
var responseData = Encoding.UTF8.GetBytes(payload);
var ms = new MemoryStream(responseData);
var exception = Assert.Throws<InvalidDataException>(() => NegotiateProtocol.ParseResponse(ms));
Assert.Equal(expectedMessage, exception.InnerException.Message);
}
}
}