From d7e19c429c9f93a995ea957a8856ed278d90dba5 Mon Sep 17 00:00:00 2001 From: Mikael Mengistu Date: Mon, 16 Sep 2019 14:21:59 -0700 Subject: [PATCH] SignalR ConnectionToken/ConnectionAddress feature (#13773) --- .../HttpConnectionTests.Negotiate.cs | 87 ++++++++++++++ .../Client/test/UnitTests/ResponseUtils.cs | 4 +- .../src/HttpConnection.cs | 19 ++- ...Core.Http.Connections.Common.netcoreapp.cs | 1 + ....Http.Connections.Common.netstandard2.0.cs | 1 + .../src/NegotiateProtocol.cs | 21 ++++ .../src/NegotiationResponse.cs | 1 + .../src/Internal/HttpConnectionContext.cs | 14 ++- .../src/Internal/HttpConnectionDispatcher.cs | 96 +++++++++------ .../src/Internal/HttpConnectionManager.cs | 18 ++- .../test/HttpConnectionDispatcherTests.cs | 112 ++++++++++++------ .../test/HttpConnectionManagerTests.cs | 55 ++++++++- .../test/NegotiateProtocolTests.cs | 27 +++-- 13 files changed, 354 insertions(+), 102 deletions(-) diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs index 47a16a48d9..5abbde0313 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs @@ -36,6 +36,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return RunInvalidNegotiateResponseTest(ResponseUtils.CreateNegotiationContent(connectionId: string.Empty), "Invalid connection id."); } + [Fact] + public Task NegotiateResponseWithNegotiateVersionRequiresConnectionToken() + { + return RunInvalidNegotiateResponseTest(ResponseUtils.CreateNegotiationContent(negotiateVersion: 1, connectionToken: null), "Invalid negotiation response received."); + } + [Fact] public Task ConnectionCannotBeStartedIfNoCommonTransportsBetweenClientAndServer() { @@ -156,6 +162,87 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.Equal("0rge0d00-0040-0030-0r00-000q00r00e00", connectionId); } + [Fact] + public async Task ConnectionIdGetsSetWithNegotiateProtocolGreaterThanZero() + { + string connectionId = null; + + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); + testHttpHandler.OnNegotiate((request, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00", + negotiateVersion = 1, + connectionToken = "different-id", + availableTransports = new object[] + { + new + { + transport = "LongPolling", + transferFormats = new[] { "Text" } + }, + }, + newField = "ignore this", + }))); + testHttpHandler.OnLongPoll(cancellationToken => ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + + using (var noErrorScope = new VerifyNoErrorsScope()) + { + await WithConnectionAsync( + CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory), + async (connection) => + { + await connection.StartAsync().OrTimeout(); + connectionId = connection.ConnectionId; + }); + } + + Assert.Equal("0rge0d00-0040-0030-0r00-000q00r00e00", connectionId); + Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); + Assert.Equal("http://fakeuri.org/?negotiateVersion=1&id=different-id", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + } + + [Fact] + public async Task ConnectionTokenFieldIsIgnoredForNegotiateIdLessThanOne() + { + string connectionId = null; + + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); + testHttpHandler.OnNegotiate((request, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00", + connectionToken = "different-id", + availableTransports = new object[] + { + new + { + transport = "LongPolling", + transferFormats = new[] { "Text" } + }, + }, + newField = "ignore this", + }))); + testHttpHandler.OnLongPoll(cancellationToken => ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + + using (var noErrorScope = new VerifyNoErrorsScope()) + { + await WithConnectionAsync( + CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory), + async (connection) => + { + await connection.StartAsync().OrTimeout(); + connectionId = connection.ConnectionId; + }); + } + + Assert.Equal("0rge0d00-0040-0030-0r00-000q00r00e00", connectionId); + Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); + Assert.Equal("http://fakeuri.org/?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + } + [Fact] public async Task NegotiateThatReturnsUrlGetFollowed() { diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/ResponseUtils.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/ResponseUtils.cs index 58d5fbb53c..33dddf6aab 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/ResponseUtils.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/ResponseUtils.cs @@ -62,7 +62,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } public static string CreateNegotiationContent(string connectionId = "00000000-0000-0000-0000-000000000000", - HttpTransportType? transportTypes = null) + HttpTransportType? transportTypes = null, string connectionToken = "connection-token", int negotiateVersion = 0) { var availableTransports = new List(); @@ -92,7 +92,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); } - return JsonConvert.SerializeObject(new { connectionId, availableTransports }); + return JsonConvert.SerializeObject(new { connectionId, availableTransports, connectionToken, negotiateVersion }); } } } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index 77205aecf5..74715c638c 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -42,6 +42,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client private readonly HttpConnectionOptions _httpConnectionOptions; private ITransport _transport; private readonly ITransportFactory _transportFactory; + private string _connectionToken; private string _connectionId; private readonly ConnectionLogScope _logScope; private readonly ILoggerFactory _loggerFactory; @@ -342,7 +343,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client } // This should only need to happen once - var connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionId); + var connectUrl = CreateConnectUrl(uri, _connectionToken); // We're going to search for the transfer format as a string because we don't want to parse // all the transfer formats in the negotiation response, and we want to allow transfer formats @@ -383,7 +384,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client if (negotiationResponse == null) { negotiationResponse = await GetNegotiationResponseAsync(uri, cancellationToken); - connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionId); + connectUrl = CreateConnectUrl(uri, _connectionToken); } Log.StartingTransport(_logger, transportType, connectUrl); @@ -629,7 +630,19 @@ namespace Microsoft.AspNetCore.Http.Connections.Client private async Task GetNegotiationResponseAsync(Uri uri, CancellationToken cancellationToken) { var negotiationResponse = await NegotiateAsync(uri, _httpClient, _logger, cancellationToken); - _connectionId = negotiationResponse.ConnectionId; + // If the negotiationVersion is greater than zero then we know that the negotiation response contains a + // connectionToken that will be required to conenct. Otherwise we just set the connectionId and the + // connectionToken on the client to the same value. + if (negotiationResponse.Version > 0) + { + _connectionId = negotiationResponse.ConnectionId; + _connectionToken = negotiationResponse.ConnectionToken; + } + else + { + _connectionToken = _connectionId = negotiationResponse.ConnectionId; + } + _logScope.ConnectionId = _connectionId; return negotiationResponse; } diff --git a/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netcoreapp.cs b/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netcoreapp.cs index fb58827e4b..f557b74f58 100644 --- a/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netcoreapp.cs +++ b/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netcoreapp.cs @@ -34,6 +34,7 @@ namespace Microsoft.AspNetCore.Http.Connections public string AccessToken { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.Collections.Generic.IList AvailableTransports { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string ConnectionId { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public string ConnectionToken { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string Error { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string Url { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public int Version { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } diff --git a/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netstandard2.0.cs b/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netstandard2.0.cs index fb58827e4b..f557b74f58 100644 --- a/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netstandard2.0.cs +++ b/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netstandard2.0.cs @@ -34,6 +34,7 @@ namespace Microsoft.AspNetCore.Http.Connections public string AccessToken { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.Collections.Generic.IList AvailableTransports { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string ConnectionId { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public string ConnectionToken { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string Error { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string Url { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public int Version { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs index 1d4c4f28ae..ae69b56cdd 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs @@ -15,6 +15,8 @@ namespace Microsoft.AspNetCore.Http.Connections { private const string ConnectionIdPropertyName = "connectionId"; private static JsonEncodedText ConnectionIdPropertyNameBytes = JsonEncodedText.Encode(ConnectionIdPropertyName); + private const string ConnectionTokenPropertyName = "connectionToken"; + private static JsonEncodedText ConnectionTokenPropertyNameBytes = JsonEncodedText.Encode(ConnectionTokenPropertyName); private const string UrlPropertyName = "url"; private static JsonEncodedText UrlPropertyNameBytes = JsonEncodedText.Encode(UrlPropertyName); private const string AccessTokenPropertyName = "accessToken"; @@ -71,6 +73,11 @@ namespace Microsoft.AspNetCore.Http.Connections writer.WriteString(ConnectionIdPropertyNameBytes, response.ConnectionId); } + if (response.Version > 0 && !string.IsNullOrEmpty(response.ConnectionToken)) + { + writer.WriteString(ConnectionTokenPropertyNameBytes, response.ConnectionToken); + } + writer.WriteStartArray(AvailableTransportsPropertyNameBytes); if (response.AvailableTransports != null) @@ -127,6 +134,7 @@ namespace Microsoft.AspNetCore.Http.Connections reader.EnsureObjectStart(); string connectionId = null; + string connectionToken = null; string url = null; string accessToken = null; List availableTransports = null; @@ -151,6 +159,10 @@ namespace Microsoft.AspNetCore.Http.Connections { connectionId = reader.ReadAsString(ConnectionIdPropertyName); } + else if (reader.ValueTextEquals(ConnectionTokenPropertyNameBytes.EncodedUtf8Bytes)) + { + connectionToken = reader.ReadAsString(ConnectionTokenPropertyName); + } else if (reader.ValueTextEquals(NegotiateVersionPropertyNameBytes.EncodedUtf8Bytes)) { version = reader.ReadAsInt32(NegotiateVersionPropertyName).GetValueOrDefault(); @@ -202,6 +214,14 @@ namespace Microsoft.AspNetCore.Http.Connections throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'."); } + if (version > 0) + { + if (connectionToken == null) + { + throw new InvalidDataException($"Missing required property '{ConnectionTokenPropertyName}'."); + } + } + if (availableTransports == null) { throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'."); @@ -211,6 +231,7 @@ namespace Microsoft.AspNetCore.Http.Connections return new NegotiationResponse { ConnectionId = connectionId, + ConnectionToken = connectionToken, Url = url, AccessToken = accessToken, AvailableTransports = availableTransports, diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs index cd21b6cb26..69810a5a71 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs @@ -10,6 +10,7 @@ namespace Microsoft.AspNetCore.Http.Connections public string Url { get; set; } public string AccessToken { get; set; } public string ConnectionId { get; set; } + public string ConnectionToken { get; set; } public int Version { get; set; } public IList AvailableTransports { get; set; } public string Error { get; set; } diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index 6e21d7a665..6d3fe467e9 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -48,11 +48,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal /// Creates the DefaultConnectionContext without Pipes to avoid upfront allocations. /// The caller is expected to set the and pipes manually. /// - /// + /// + /// /// - public HttpConnectionContext(string id, ILogger logger) + public HttpConnectionContext(string connectionId, string connectionToken, ILogger logger) { - ConnectionId = id; + ConnectionId = connectionId; + ConnectionToken = connectionToken; LastSeenUtc = DateTime.UtcNow; // The default behavior is that both formats are supported. @@ -74,8 +76,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal Features.Set(this); } - public HttpConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application, ILogger logger = null) - : this(id, logger) + internal HttpConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application, ILogger logger = null) + : this(id, null, logger) { Transport = transport; Application = application; @@ -113,6 +115,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal public override string ConnectionId { get; set; } + internal string ConnectionToken { get; set; } + public override IFeatureCollection Features { get; } public ClaimsPrincipal User { get; set; } diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index 983b1270f6..7bd4acc682 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -59,7 +59,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal // Create the log scope and attempt to pass the Connection ID to it so as many logs as possible contain // the Connection ID metadata. If this is the negotiate request then the Connection ID for the scope will // be set a little later. - var logScope = new ConnectionLogScope(GetConnectionId(context)); + + HttpConnectionContext connectionContext = null; + var connectionToken = GetConnectionToken(context); + if (connectionToken != null) + { + _manager.TryGetConnection(GetConnectionToken(context), out connectionContext); + } + + var logScope = new ConnectionLogScope(connectionContext?.ConnectionId); using (_logger.BeginScope(logScope)) { if (HttpMethods.IsPost(context.Request.Method)) @@ -279,13 +287,29 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatcherOptions options, ConnectionLogScope logScope) { context.Response.ContentType = "application/json"; + string error = null; + int clientProtocolVersion = 0; + if (context.Request.Query.TryGetValue("NegotiateVersion", out var queryStringVersion)) + { + // Set the negotiate response to the protocol we use. + var queryStringVersionValue = queryStringVersion.ToString(); + if (!int.TryParse(queryStringVersionValue, out clientProtocolVersion)) + { + error = $"The client requested an invalid protocol version '{queryStringVersionValue}'"; + Log.InvalidNegotiateProtocolVersion(_logger, queryStringVersionValue); + } + } // Establish the connection - var connection = CreateConnection(options); + HttpConnectionContext connection = null; + if (error == null) + { + connection = CreateConnection(options, clientProtocolVersion); + } // Set the Connection ID on the logging scope so that logs from now on will have the // Connection ID metadata set. - logScope.ConnectionId = connection.ConnectionId; + logScope.ConnectionId = connection?.ConnectionId; // Don't use thread static instance here because writer is used with async var writer = new MemoryBufferWriter(); @@ -293,7 +317,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal try { // Get the bytes for the connection id - WriteNegotiatePayload(writer, connection.ConnectionId, context, options); + WriteNegotiatePayload(writer, connection?.ConnectionId, connection?.ConnectionToken, context, options, clientProtocolVersion, error); Log.NegotiationRequest(_logger); @@ -307,38 +331,34 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal } } - private void WriteNegotiatePayload(IBufferWriter writer, string connectionId, HttpContext context, HttpConnectionDispatcherOptions options) + private void WriteNegotiatePayload(IBufferWriter writer, string connectionId, string connectionToken, HttpContext context, HttpConnectionDispatcherOptions options, + int clientProtocolVersion, string error) { var response = new NegotiationResponse(); - if (context.Request.Query.TryGetValue("NegotiateVersion", out var queryStringVersion)) + if (!string.IsNullOrEmpty(error)) { - // Set the negotiate response to the protocol we use. - var queryStringVersionValue = queryStringVersion.ToString(); - if (int.TryParse(queryStringVersionValue, out var clientProtocolVersion)) + response.Error = error; + NegotiateProtocol.WriteResponse(response, writer); + return; + } + + if (clientProtocolVersion > 0) + { + if (clientProtocolVersion < options.MinimumProtocolVersion) { - if (clientProtocolVersion < options.MinimumProtocolVersion) - { - response.Error = $"The client requested version '{clientProtocolVersion}', but the server does not support this version."; - Log.NegotiateProtocolVersionMismatch(_logger, clientProtocolVersion); - NegotiateProtocol.WriteResponse(response, writer); - return; - } - else if (clientProtocolVersion > _protocolVersion) - { - response.Version = _protocolVersion; - } - else - { - response.Version = clientProtocolVersion; - } + response.Error = $"The client requested version '{clientProtocolVersion}', but the server does not support this version."; + Log.NegotiateProtocolVersionMismatch(_logger, clientProtocolVersion); + NegotiateProtocol.WriteResponse(response, writer); + return; + } + else if (clientProtocolVersion > _protocolVersion) + { + response.Version = _protocolVersion; } else { - response.Error = $"The client requested an invalid protocol version '{queryStringVersionValue}'"; - Log.InvalidNegotiateProtocolVersion(_logger, queryStringVersionValue); - NegotiateProtocol.WriteResponse(response, writer); - return; + response.Version = clientProtocolVersion; } } else if (options.MinimumProtocolVersion > 0) @@ -350,6 +370,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal } response.ConnectionId = connectionId; + response.ConnectionToken = connectionToken; response.AvailableTransports = new List(); if ((options.Transports & HttpTransportType.WebSockets) != 0 && ServerHasWebSockets(context.Features)) @@ -375,7 +396,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal return features.Get() != null; } - private static string GetConnectionId(HttpContext context) => context.Request.Query["id"]; + private static string GetConnectionToken(HttpContext context) => context.Request.Query["id"]; private async Task ProcessSend(HttpContext context, HttpConnectionDispatcherOptions options) { @@ -648,9 +669,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private async Task GetConnectionAsync(HttpContext context) { - var connectionId = GetConnectionId(context); + var connectionToken = GetConnectionToken(context); - if (StringValues.IsNullOrEmpty(connectionId)) + if (StringValues.IsNullOrEmpty(connectionToken)) { // There's no connection ID: bad request context.Response.StatusCode = StatusCodes.Status400BadRequest; @@ -659,7 +680,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal return null; } - if (!_manager.TryGetConnection(connectionId, out var connection)) + if (!_manager.TryGetConnection(connectionToken, out var connection)) { // No connection with that ID: Not Found context.Response.StatusCode = StatusCodes.Status404NotFound; @@ -674,15 +695,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal // This is only used for WebSockets connections, which can connect directly without negotiating private async Task GetOrCreateConnectionAsync(HttpContext context, HttpConnectionDispatcherOptions options) { - var connectionId = GetConnectionId(context); + var connectionToken = GetConnectionToken(context); HttpConnectionContext connection; // There's no connection id so this is a brand new connection - if (StringValues.IsNullOrEmpty(connectionId)) + if (StringValues.IsNullOrEmpty(connectionToken)) { connection = CreateConnection(options); } - else if (!_manager.TryGetConnection(connectionId, out connection)) + else if (!_manager.TryGetConnection(connectionToken, out connection)) { // No connection with that ID: Not Found context.Response.StatusCode = StatusCodes.Status404NotFound; @@ -693,12 +714,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal return connection; } - private HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options) + private HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, int clientProtocolVersion = 0) { var transportPipeOptions = new PipeOptions(pauseWriterThreshold: options.TransportMaxBufferSize, resumeWriterThreshold: options.TransportMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); var appPipeOptions = new PipeOptions(pauseWriterThreshold: options.ApplicationMaxBufferSize, resumeWriterThreshold: options.ApplicationMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); - - return _manager.CreateConnection(transportPipeOptions, appPipeOptions); + return _manager.CreateConnection(transportPipeOptions, appPipeOptions, clientProtocolVersion); } private class EmptyServiceProvider : IServiceProvider diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index dda35866a4..4a97681fc0 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -78,18 +78,28 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal /// Creates a connection without Pipes setup to allow saving allocations until Pipes are needed. /// /// - internal HttpConnectionContext CreateConnection(PipeOptions transportPipeOptions, PipeOptions appPipeOptions) + internal HttpConnectionContext CreateConnection(PipeOptions transportPipeOptions, PipeOptions appPipeOptions, int negotiateVersion = 0) { + string connectionToken; var id = MakeNewConnectionId(); + if (negotiateVersion > 0) + { + connectionToken = MakeNewConnectionId(); + } + else + { + connectionToken = id; + } Log.CreatedNewConnection(_logger, id); var connectionTimer = HttpConnectionsEventSource.Log.ConnectionStart(id); - var connection = new HttpConnectionContext(id, _connectionLogger); + var connection = new HttpConnectionContext(id, connectionToken, _connectionLogger); var pair = DuplexPipe.CreateConnectionPair(transportPipeOptions, appPipeOptions); connection.Transport = pair.Application; connection.Application = pair.Transport; - _connections.TryAdd(id, (connection, connectionTimer)); + _connections.TryAdd(connectionToken, (connection, connectionTimer)); + return connection; } @@ -205,7 +215,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal { // Remove it from the list after disposal so that's it's easy to see // connections that might be in a hung state via the connections list - RemoveConnection(connection.ConnectionId); + RemoveConnection(connection.ConnectionToken); } } } diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index 1be53a1dcd..fe5ae6dda8 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -38,7 +38,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests public class HttpConnectionDispatcherTests : VerifiableLoggedTest { [Fact] - public async Task NegotiateReservesConnectionIdAndReturnsIt() + public async Task NegotiateVersionZeroReservesConnectionIdAndReturnsIt() { using (StartVerifiableLog()) { @@ -55,8 +55,35 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions()); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); var connectionId = negotiateResponse.Value("connectionId"); - Assert.True(manager.TryGetConnection(connectionId, out var connectionContext)); + var connectionToken = negotiateResponse.Value("connectionToken"); + Assert.Null(connectionToken); + Assert.NotNull(connectionId); + } + } + + [Fact] + public async Task NegotiateReservesConnectionTokenAndConnectionIdAndReturnsIt() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddOptions(); + var ms = new MemoryStream(); + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + context.Response.Body = ms; + context.Request.QueryString = new QueryString("?negotiateVersion=1"); + await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions()); + var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); + var connectionId = negotiateResponse.Value("connectionId"); + var connectionToken = negotiateResponse.Value("connectionToken"); + Assert.True(manager.TryGetConnection(connectionToken, out var connectionContext)); Assert.Equal(connectionId, connectionContext.ConnectionId); + Assert.NotEqual(connectionId, connectionToken); } } @@ -75,12 +102,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; context.Response.Body = ms; + context.Request.QueryString = new QueryString("?negotiateVersion=1"); var options = new HttpConnectionDispatcherOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4 }; await dispatcher.ExecuteNegotiateAsync(context, options); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); - var connectionId = negotiateResponse.Value("connectionId"); - context.Request.QueryString = context.Request.QueryString.Add("id", connectionId); - Assert.True(manager.TryGetConnection(connectionId, out var connection)); + var connectionToken = negotiateResponse.Value("connectionToken"); + context.Request.QueryString = context.Request.QueryString.Add("id", connectionToken); + Assert.True(manager.TryGetConnection(connectionToken, out var connection)); // Fake actual connection after negotiate to populate the pipes on the connection await dispatcher.ExecuteAsync(context, options, c => Task.CompletedTask); @@ -112,7 +140,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Method = "POST"; context.Response.Body = ms; context.Request.QueryString = new QueryString("?negotiateVersion=Invalid"); - var options = new HttpConnectionDispatcherOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4 }; await dispatcher.ExecuteNegotiateAsync(context, options); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); @@ -141,7 +168,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Method = "POST"; context.Response.Body = ms; context.Request.QueryString = new QueryString(""); - var options = new HttpConnectionDispatcherOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4, MinimumProtocolVersion = 1 }; await dispatcher.ExecuteNegotiateAsync(context, options); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); @@ -184,7 +210,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -225,6 +252,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; context.Response.Body = ms; + context.Request.QueryString = new QueryString("?negotiateVersion=1"); await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions { Transports = transports }); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); @@ -263,6 +291,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Method = "GET"; var values = new Dictionary(); values["id"] = "unknown"; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; SetTransport(context, transportType); @@ -299,6 +328,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Method = "POST"; var values = new Dictionary(); values["id"] = "unknown"; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -335,7 +365,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -374,6 +405,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Method = "POST"; var values = new Dictionary(); values["id"] = connection.ConnectionId; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -413,7 +445,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -474,7 +507,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -540,7 +574,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -603,6 +638,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Method = "POST"; var values = new Dictionary(); values["id"] = connection.ConnectionId; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -672,7 +708,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; values["another"] = "value"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -720,8 +757,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var connectionHttpContext = connection.GetHttpContext(); Assert.NotNull(connectionHttpContext); - Assert.Equal(2, connectionHttpContext.Request.Query.Count); - Assert.Equal(connection.ConnectionId, connectionHttpContext.Request.Query["id"]); + Assert.Equal(3, connectionHttpContext.Request.Query.Count); Assert.Equal("value", connectionHttpContext.Request.Query["another"]); Assert.Equal(3, connectionHttpContext.Request.Headers.Count); @@ -765,6 +801,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests services.AddSingleton(); context.Request.Path = "/foo"; context.Request.Method = "GET"; + context.Request.QueryString = new QueryString("?negotiateVersion=1"); SetTransport(context, transportType); @@ -807,7 +844,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -834,6 +872,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests services.AddSingleton(); context.Request.Path = "/foo"; context.Request.Method = "POST"; + context.Request.QueryString = new QueryString("?negotiateVersion=1"); var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); @@ -905,6 +944,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); var context = MakeRequest("/foo", connection); + SetTransport(context, HttpTransportType.ServerSentEvents); var services = new ServiceCollection(); @@ -916,7 +956,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - var exists = manager.TryGetConnection(connection.ConnectionId, out _); + var exists = manager.TryGetConnection(connection.ConnectionToken, out _); Assert.False(exists); } } @@ -1280,7 +1320,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests await task; Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - var exists = manager.TryGetConnection(connection.ConnectionId, out _); + var exists = manager.TryGetConnection(connection.ConnectionToken, out _); Assert.False(exists); } } @@ -1321,7 +1361,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests await task; Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); - var exists = manager.TryGetConnection(connection.ConnectionId, out _); + var exists = manager.TryGetConnection(connection.ConnectionToken, out _); Assert.False(exists); } } @@ -1423,10 +1463,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Method = "GET"; context.RequestServices = sp; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; - var builder = new ConnectionBuilder(sp); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1511,7 +1551,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests // Issue the delete request var deleteContext = new DefaultHttpContext(); deleteContext.Request.Path = "/foo"; - deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); + deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionToken}"); deleteContext.Request.Method = "DELETE"; var ms = new MemoryStream(); deleteContext.Response.Body = ms; @@ -1554,7 +1594,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests // Issue the delete request and make sure the poll completes var deleteContext = new DefaultHttpContext(); deleteContext.Request.Path = "/foo"; - deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); + deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionToken}"); deleteContext.Request.Method = "DELETE"; Assert.False(pollTask.IsCompleted); @@ -1572,7 +1612,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.Equal("text/plain", deleteContext.Response.ContentType); // Verify the connection was removed from the manager - Assert.False(manager.TryGetConnection(connection.ConnectionId, out _)); + Assert.False(manager.TryGetConnection(connection.ConnectionToken, out _)); } } @@ -1602,7 +1642,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests // Issue the delete request and make sure the poll completes var deleteContext = new DefaultHttpContext(); deleteContext.Request.Path = "/foo"; - deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); + deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionToken}"); deleteContext.Request.Method = "DELETE"; await dispatcher.ExecuteAsync(deleteContext, options, app).OrTimeout(); @@ -1620,7 +1660,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests await connection.DisposeAndRemoveTask.OrTimeout(); // Verify the connection was removed from the manager - Assert.False(manager.TryGetConnection(connection.ConnectionId, out _)); + Assert.False(manager.TryGetConnection(connection.ConnectionToken, out _)); } } @@ -1640,6 +1680,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; context.Response.Body = ms; + context.Request.QueryString = new QueryString("?negotiateVersion=1"); await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions { Transports = HttpTransportType.WebSockets }); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); @@ -1696,7 +1737,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; var buffer = Encoding.UTF8.GetBytes("Hello, world"); @@ -1752,7 +1794,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; var buffer = Encoding.UTF8.GetBytes("Hello, world"); @@ -1805,7 +1848,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; var buffer = Encoding.UTF8.GetBytes("Hello, world"); @@ -1867,7 +1911,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests await pollTask.OrTimeout(); Assert.Equal(StatusCodes.Status500InternalServerError, pollContext.Response.StatusCode); - Assert.False(manager.TryGetConnection(connection.ConnectionId, out var _)); + Assert.False(manager.TryGetConnection(connection.ConnectionToken, out var _)); } } @@ -1890,7 +1934,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -1912,14 +1957,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } - private static DefaultHttpContext MakeRequest(string path, ConnectionContext connection, string format = null) + private static DefaultHttpContext MakeRequest(string path, HttpConnectionContext connection, string format = null) { var context = new DefaultHttpContext(); context.Features.Set(new ResponseFeature()); context.Request.Path = path; context.Request.Method = "GET"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; if (format != null) { values["format"] = format; diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs index 5c30a490f7..ade605b08a 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs @@ -131,7 +131,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.NotNull(connection.ConnectionId); - Assert.True(connectionManager.TryGetConnection(connection.ConnectionId, out var newConnection)); + Assert.True(connectionManager.TryGetConnection(connection.ConnectionToken, out var newConnection)); Assert.Same(newConnection, connection); } } @@ -143,13 +143,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var connectionManager = CreateConnectionManager(LoggerFactory); var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default); - var transport = connection.Transport; Assert.NotNull(connection.ConnectionId); + Assert.NotNull(connection.ConnectionToken); Assert.NotNull(transport); - Assert.True(connectionManager.TryGetConnection(connection.ConnectionId, out var newConnection)); + Assert.True(connectionManager.TryGetConnection(connection.ConnectionToken, out var newConnection)); Assert.Same(newConnection, connection); Assert.Same(transport, newConnection.Transport); } @@ -168,12 +168,55 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.NotNull(connection.ConnectionId); Assert.NotNull(transport); - Assert.True(connectionManager.TryGetConnection(connection.ConnectionId, out var newConnection)); + Assert.True(connectionManager.TryGetConnection(connection.ConnectionToken, out var newConnection)); Assert.Same(newConnection, connection); Assert.Same(transport, newConnection.Transport); - connectionManager.RemoveConnection(connection.ConnectionId); - Assert.False(connectionManager.TryGetConnection(connection.ConnectionId, out newConnection)); + connectionManager.RemoveConnection(connection.ConnectionToken); + Assert.False(connectionManager.TryGetConnection(connection.ConnectionToken, out newConnection)); + } + } + + [Fact] + public void ConnectionIdAndConnectionTokenAreTheSameForNegotiateVersionZero() + { + using (StartVerifiableLog()) + { + var connectionManager = CreateConnectionManager(LoggerFactory); + var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default, negotiateVersion: 0); + + var transport = connection.Transport; + + Assert.NotNull(connection.ConnectionId); + Assert.NotNull(transport); + + Assert.True(connectionManager.TryGetConnection(connection.ConnectionToken, out var newConnection)); + Assert.Same(newConnection, connection); + Assert.Same(transport, newConnection.Transport); + Assert.Equal(connection.ConnectionId, connection.ConnectionToken); + + } + } + + [Fact] + public void ConnectionIdAndConnectionTokenAreDifferentForNegotiateVersionOne() + { + using (StartVerifiableLog()) + { + var connectionManager = CreateConnectionManager(LoggerFactory); + var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default, negotiateVersion: 1); + + var transport = connection.Transport; + + Assert.NotNull(connection.ConnectionId); + Assert.NotNull(transport); + + Assert.True(connectionManager.TryGetConnection(connection.ConnectionToken, out var newConnection)); + Assert.False(connectionManager.TryGetConnection(connection.ConnectionId, out var _)); + Assert.Same(newConnection, connection); + Assert.Same(transport, newConnection.Transport); + Assert.NotEqual(connection.ConnectionId, connection.ConnectionToken); + } } diff --git a/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs b/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs index c7d274d0b3..00d803ffdd 100644 --- a/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs +++ b/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs @@ -13,17 +13,20 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests public class NegotiateProtocolTests { [Theory] - [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null, 0)] - [InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0], null, null, 0)] - [InlineData("{\"url\": \"http://foo.com/chat\"}", null, null, "http://foo.com/chat", null, 0)] - [InlineData("{\"url\": \"http://foo.com/chat\", \"accessToken\": \"token\"}", null, null, "http://foo.com/chat", "token", 0)] - [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null, 0)] - [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"\\u0074ransport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null, 0)] - [InlineData("{\"negotiateVersion\":123,\"connectionId\":\"123\",\"availableTransports\":[{\"\\u0074ransport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null, 123)] - [InlineData("{\"negotiateVersion\":123,\"negotiateVersion\":321,\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null, 321)] - [InlineData("{\"ignore\":123,\"negotiateVersion\":123,\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null, 123)] - [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[],\"negotiateVersion\":123}", "123", new string[0], null, null, 123)] - public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports, string url, string accessToken, int version) + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null, 0, null)] + [InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0], null, null, 0, null)] + [InlineData("{\"url\": \"http://foo.com/chat\"}", null, null, "http://foo.com/chat", null, 0, null)] + [InlineData("{\"url\": \"http://foo.com/chat\", \"accessToken\": \"token\"}", null, null, "http://foo.com/chat", "token", 0, null)] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null, 0, null)] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"\\u0074ransport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null, 0, null)] + [InlineData("{\"negotiateVersion\":123,\"connectionId\":\"123\",\"connectionToken\":\"789\",\"availableTransports\":[{\"\\u0074ransport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null, 123, "789")] + [InlineData("{\"negotiateVersion\":123,\"negotiateVersion\":321, \"connectionToken\":\"789\",\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null, 321, "789")] + [InlineData("{\"ignore\":123,\"negotiateVersion\":123, \"connectionToken\":\"789\",\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null, 123, "789")] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[],\"negotiateVersion\":123, \"connectionToken\":\"789\"}", "123", new string[0], null, null, 123, "789")] + [InlineData("{\"connectionId\":\"123\",\"connectionToken\":\"789\",\"availableTransports\":[]}", "123", new string[0], null, null, 0, "789")] + [InlineData("{\"connectionToken\":\"789\",\"connectionId\":\"123\",\"availableTransports\":[],\"negotiateVersion\":123}", "123", new string[0], null, null, 123, "789")] + [InlineData("{\"connectionToken\":\"789\",\"connectionId\":\"123\",\"availableTransports\":[],\"negotiateVersion\":123, \"connectionToken\":\"987\"}", "123", new string[0], null, null, 123, "987")] + public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports, string url, string accessToken, int version, string connectionToken) { var responseData = Encoding.UTF8.GetBytes(json); var response = NegotiateProtocol.ParseResponse(responseData); @@ -33,6 +36,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.Equal(url, response.Url); Assert.Equal(accessToken, response.AccessToken); Assert.Equal(version, response.Version); + Assert.Equal(connectionToken, response.ConnectionToken); if (response.AvailableTransports != null) { @@ -50,6 +54,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests [InlineData("{\"connectionId\":\"123\",\"availableTransports\":null}", "Unexpected JSON Token Type 'Null'. Expected a JSON Array.")] [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transferFormats\":[]}]}", "Missing required property 'transport'.")] [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\"}]}", "Missing required property 'transferFormats'.")] + [InlineData("{\"connectionId\":\"123\",\"negotiateVersion\":123,\"availableTransports\":[]}", "Missing required property 'connectionToken'.")] public void ParsingNegotiateResponseMessageThrowsForInvalid(string payload, string expectedMessage) { var responseData = Encoding.UTF8.GetBytes(payload);