diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index 14e9d78e5b..6dbd4e032e 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -114,6 +114,71 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Fact] + public async Task ServerRejectsClientWithOldProtocol() + { + bool ExpectedError(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HttpConnection).FullName && + writeContext.EventId.Name == "ErrorWithNegotiation"; + } + + var protocol = HubProtocols["json"]; + using (StartServer(out var server, ExpectedError)) + { + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + "/negotiateProtocolVersion12", HttpTransportType.LongPolling); + connectionBuilder.Services.AddSingleton(protocol); + + var connection = connectionBuilder.Build(); + + try + { + var ex = await Assert.ThrowsAnyAsync(() => connection.StartAsync()).OrTimeout(); + Assert.Equal("The client requested version '1', but the server does not support this version.", ex.Message); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + + [Fact] + public async Task ClientCanConnectToServerWithLowerMinimumProtocol() + { + var protocol = HubProtocols["json"]; + using (StartServer(out var server)) + { + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + "/negotiateProtocolVersionNegative", HttpTransportType.LongPolling); + connectionBuilder.Services.AddSingleton(protocol); + + var connection = connectionBuilder.Build(); + + try + { + await connection.StartAsync().OrTimeout(); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] public async Task CanSendAndReceiveMessage(string protocolName, HttpTransportType transportType, string path) diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/Startup.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/Startup.cs index 1d7dbd6718..4cbc35c510 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/Startup.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/Startup.cs @@ -69,6 +69,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests endpoints.MapHub("/default-nowebsockets", options => options.Transports = HttpTransportType.LongPolling | HttpTransportType.ServerSentEvents); + endpoints.MapHub("/negotiateProtocolVersion12", options => + { + options.MinimumProtocolVersion = 12; + }); + + endpoints.MapHub("/negotiateProtocolVersionNegative", options => + { + options.MinimumProtocolVersion = -1; + }); + endpoints.MapGet("/generateJwtToken", context => { return context.Response.WriteAsync(GenerateJwtToken()); diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs index 348e33cebf..a1392a3cb5 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs @@ -36,6 +36,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return RunInvalidNegotiateResponseTest(ResponseUtils.CreateNegotiationContent(connectionId: string.Empty), "Invalid connection id."); } + [Fact] + public Task NegotiateResponseWithNegotiateVersionRequiresConnectionToken() + { + return RunInvalidNegotiateResponseTest(ResponseUtils.CreateNegotiationContent(negotiateVersion: 1, connectionToken: null), "Invalid negotiation response received."); + } + [Fact] public Task ConnectionCannotBeStartedIfNoCommonTransportsBetweenClientAndServer() { @@ -50,12 +56,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } [Theory] - [InlineData("http://fakeuri.org/", "http://fakeuri.org/negotiate")] - [InlineData("http://fakeuri.org/?q=1/0", "http://fakeuri.org/negotiate?q=1/0")] - [InlineData("http://fakeuri.org?q=1/0", "http://fakeuri.org/negotiate?q=1/0")] - [InlineData("http://fakeuri.org/endpoint", "http://fakeuri.org/endpoint/negotiate")] - [InlineData("http://fakeuri.org/endpoint/", "http://fakeuri.org/endpoint/negotiate")] - [InlineData("http://fakeuri.org/endpoint?q=1/0", "http://fakeuri.org/endpoint/negotiate?q=1/0")] + [InlineData("http://fakeuri.org/", "http://fakeuri.org/negotiate?negotiateVersion=1")] + [InlineData("http://fakeuri.org/?q=1/0", "http://fakeuri.org/negotiate?q=1/0&negotiateVersion=1")] + [InlineData("http://fakeuri.org?q=1/0", "http://fakeuri.org/negotiate?q=1/0&negotiateVersion=1")] + [InlineData("http://fakeuri.org/endpoint", "http://fakeuri.org/endpoint/negotiate?negotiateVersion=1")] + [InlineData("http://fakeuri.org/endpoint/", "http://fakeuri.org/endpoint/negotiate?negotiateVersion=1")] + [InlineData("http://fakeuri.org/endpoint?q=1/0", "http://fakeuri.org/endpoint/negotiate?q=1/0&negotiateVersion=1")] public async Task CorrectlyHandlesQueryStringWhenAppendingNegotiateToUrl(string requestedUrl, string expectedNegotiate) { var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); @@ -119,6 +125,124 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.Equal("0rge0d00-0040-0030-0r00-000q00r00e00", connectionId); } + [Fact] + public async Task NegotiateCanHaveNewFields() + { + string connectionId = null; + + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); + testHttpHandler.OnNegotiate((request, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00", + availableTransports = new object[] + { + new + { + transport = "LongPolling", + transferFormats = new[] { "Text" } + }, + }, + newField = "ignore this", + }))); + testHttpHandler.OnLongPoll(cancellationToken => ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + + using (var noErrorScope = new VerifyNoErrorsScope()) + { + await WithConnectionAsync( + CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory), + async (connection) => + { + await connection.StartAsync().OrTimeout(); + connectionId = connection.ConnectionId; + }); + } + + Assert.Equal("0rge0d00-0040-0030-0r00-000q00r00e00", connectionId); + } + + [Fact] + public async Task ConnectionIdGetsSetWithNegotiateProtocolGreaterThanZero() + { + string connectionId = null; + + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); + testHttpHandler.OnNegotiate((request, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00", + negotiateVersion = 1, + connectionToken = "different-id", + availableTransports = new object[] + { + new + { + transport = "LongPolling", + transferFormats = new[] { "Text" } + }, + }, + newField = "ignore this", + }))); + testHttpHandler.OnLongPoll(cancellationToken => ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + + using (var noErrorScope = new VerifyNoErrorsScope()) + { + await WithConnectionAsync( + CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory), + async (connection) => + { + await connection.StartAsync().OrTimeout(); + connectionId = connection.ConnectionId; + }); + } + + Assert.Equal("0rge0d00-0040-0030-0r00-000q00r00e00", connectionId); + Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); + Assert.Equal("http://fakeuri.org/?id=different-id", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + } + + [Fact] + public async Task ConnectionTokenFieldIsIgnoredForNegotiateIdLessThanOne() + { + string connectionId = null; + + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); + testHttpHandler.OnNegotiate((request, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00", + connectionToken = "different-id", + availableTransports = new object[] + { + new + { + transport = "LongPolling", + transferFormats = new[] { "Text" } + }, + }, + newField = "ignore this", + }))); + testHttpHandler.OnLongPoll(cancellationToken => ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + + using (var noErrorScope = new VerifyNoErrorsScope()) + { + await WithConnectionAsync( + CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory), + async (connection) => + { + await connection.StartAsync().OrTimeout(); + connectionId = connection.ConnectionId; + }); + } + + Assert.Equal("0rge0d00-0040-0030-0r00-000q00r00e00", connectionId); + Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); + Assert.Equal("http://fakeuri.org/?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + } + [Fact] public async Task NegotiateThatReturnsUrlGetFollowed() { @@ -172,8 +296,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); } - Assert.Equal("http://fakeuri.org/negotiate", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); - Assert.Equal("https://another.domain.url/chat/negotiate", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); Assert.Equal(5, testHttpHandler.ReceivedRequests.Count); @@ -278,14 +402,76 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); } - Assert.Equal("http://fakeuri.org/negotiate", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); - Assert.Equal("https://another.domain.url/chat/negotiate", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); // Delete request Assert.Equal(5, testHttpHandler.ReceivedRequests.Count); } + [Fact] + public async Task NegotiateThatReturnsRedirectUrlDoesNotAddAnotherNegotiateVersionQueryString() + { + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); + var negotiateCount = 0; + testHttpHandler.OnNegotiate((request, cancellationToken) => + { + negotiateCount++; + if (negotiateCount == 1) + { + return ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + url = "https://another.domain.url/chat?negotiateVersion=1" + })); + } + else + { + return ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00", + availableTransports = new object[] + { + new + { + transport = "LongPolling", + transferFormats = new[] { "Text" } + }, + } + })); + } + }); + + testHttpHandler.OnLongPoll((token) => + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + token.Register(() => tcs.TrySetResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent))); + + return tcs.Task; + }); + + testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + + using (var noErrorScope = new VerifyNoErrorsScope()) + { + await WithConnectionAsync( + CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory), + async (connection) => + { + await connection.StartAsync().OrTimeout(); + }); + } + + Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); + Assert.Equal(5, testHttpHandler.ReceivedRequests.Count); + } + [Fact] public async Task StartSkipsOverTransportsThatTheClientDoesNotUnderstand() { diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/ResponseUtils.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/ResponseUtils.cs index 58d5fbb53c..33dddf6aab 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/ResponseUtils.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/ResponseUtils.cs @@ -62,7 +62,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } public static string CreateNegotiationContent(string connectionId = "00000000-0000-0000-0000-000000000000", - HttpTransportType? transportTypes = null) + HttpTransportType? transportTypes = null, string connectionToken = "connection-token", int negotiateVersion = 0) { var availableTransports = new List(); @@ -92,7 +92,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); } - return JsonConvert.SerializeObject(new { connectionId, availableTransports }); + return JsonConvert.SerializeObject(new { connectionId, availableTransports, connectionToken, negotiateVersion }); } } } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/TestHttpMessageHandler.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/TestHttpMessageHandler.cs index 06d05da7f5..8144d9d574 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/TestHttpMessageHandler.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/TestHttpMessageHandler.cs @@ -1,3 +1,6 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + using System; using System.Collections.Generic; using System.Net; @@ -117,7 +120,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); testHttpMessageHandler.OnRequest((request, next, cancellationToken) => { - if (request.Method.Equals(HttpMethod.Delete) && request.RequestUri.PathAndQuery.StartsWith("/?id=")) + if (request.Method.Equals(HttpMethod.Delete) && request.RequestUri.PathAndQuery.Contains("id=")) { deleteCts.Cancel(); return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index 852681d963..1fb9ba10aa 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -26,6 +26,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client // Not configurable on purpose, high enough that if we reach here, it's likely // a buggy server private static readonly int _maxRedirects = 100; + private static readonly int _protocolVersionNumber = 1; private static readonly Task _noAccessToken = Task.FromResult(null); private static readonly TimeSpan HttpClientTimeout = TimeSpan.FromSeconds(120); @@ -341,7 +342,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client } // This should only need to happen once - var connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionId); + var connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); // We're going to search for the transfer format as a string because we don't want to parse // all the transfer formats in the negotiation response, and we want to allow transfer formats @@ -382,10 +383,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Client if (negotiationResponse == null) { negotiationResponse = await GetNegotiationResponseAsync(uri, cancellationToken); - connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionId); + connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); } - Log.StartingTransport(_logger, transportType, connectUrl); + Log.StartingTransport(_logger, transportType, uri); await StartTransport(connectUrl, transportType, transferFormat, cancellationToken); break; } @@ -428,8 +429,17 @@ namespace Microsoft.AspNetCore.Http.Connections.Client urlBuilder.Path += "/"; } urlBuilder.Path += "negotiate"; + Uri uri; + if (urlBuilder.Query.Contains("negotiateVersion")) + { + uri = urlBuilder.Uri; + } + else + { + uri = Utils.AppendQueryString(urlBuilder.Uri, $"negotiateVersion={_protocolVersionNumber}"); + } - using (var request = new HttpRequestMessage(HttpMethod.Post, urlBuilder.Uri)) + using (var request = new HttpRequestMessage(HttpMethod.Post, uri)) { // Corefx changed the default version and High Sierra curlhandler tries to upgrade request request.Version = new Version(1, 1); @@ -466,7 +476,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client throw new FormatException("Invalid connection id."); } - return Utils.AppendQueryString(url, "id=" + connectionId); + return Utils.AppendQueryString(url, $"id={connectionId}"); } private async Task StartTransport(Uri connectUrl, HttpTransportType transportType, TransferFormat transferFormat, CancellationToken cancellationToken) @@ -607,7 +617,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Client private async Task GetNegotiationResponseAsync(Uri uri, CancellationToken cancellationToken) { var negotiationResponse = await NegotiateAsync(uri, _httpClient, _logger, cancellationToken); + // If the negotiationVersion is greater than zero then we know that the negotiation response contains a + // connectionToken that will be required to conenct. Otherwise we just set the connectionId and the + // connectionToken on the client to the same value. _connectionId = negotiationResponse.ConnectionId; + if (negotiationResponse.Version == 0) + { + negotiationResponse.ConnectionToken = _connectionId; + } + _logScope.ConnectionId = _connectionId; return negotiationResponse; } diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java index aa333fb508..36dca4f808 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java @@ -57,6 +57,7 @@ public class HubConnection { private Map streamMap = new ConcurrentHashMap<>(); private TransportEnum transportEnum = TransportEnum.ALL; private String connectionId; + private final int negotiateVersion = 1; private final Logger logger = LoggerFactory.getLogger(HubConnection.class); /** @@ -260,7 +261,7 @@ public class HubConnection { HttpRequest request = new HttpRequest(); request.addHeaders(this.localHeaders); - return httpClient.post(Negotiate.resolveNegotiateUrl(url), request).map((response) -> { + return httpClient.post(Negotiate.resolveNegotiateUrl(url, this.negotiateVersion), request).map((response) -> { if (response.getStatusCode() != 200) { throw new RuntimeException(String.format("Unexpected status code returned from negotiate: %d %s.", response.getStatusCode(), response.getStatusText())); @@ -376,7 +377,6 @@ public class HubConnection { hubConnectionStateLock.lock(); try { hubConnectionState = HubConnectionState.CONNECTED; - this.connectionId = negotiateResponse.getConnectionId(); logger.info("HubConnection started."); resetServerTimeout(); //Don't send pings if we're using long polling. @@ -446,14 +446,16 @@ public class HubConnection { throw new RuntimeException("There were no compatible transports on the server."); } - String finalUrl = url; - if (response.getConnectionId() != null) { - if (url.contains("?")) { - finalUrl = url + "&id=" + response.getConnectionId(); - } else { - finalUrl = url + "?id=" + response.getConnectionId(); - } + String connectionToken = ""; + if (response.getVersion() > 0) { + this.connectionId = response.getConnectionId(); + connectionToken = response.getConnectionToken(); + } else { + connectionToken = this.connectionId = response.getConnectionId(); } + + String finalUrl = Utils.appendQueryString(url, "id=" + connectionToken); + response.setFinalUrl(finalUrl); return Single.just(response); } diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/Negotiate.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/Negotiate.java index d63359b90c..73dc0ddf64 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/Negotiate.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/Negotiate.java @@ -4,18 +4,18 @@ package com.microsoft.signalr; class Negotiate { - public static String resolveNegotiateUrl(String url) { + public static String resolveNegotiateUrl(String url, int negotiateVersion) { String negotiateUrl = ""; // Check if we have a query string. If we do then we ignore it for now. int queryStringIndex = url.indexOf('?'); if (queryStringIndex > 0) { - negotiateUrl = url.substring(0, url.indexOf('?')); + negotiateUrl = url.substring(0, queryStringIndex); } else { negotiateUrl = url; } - //Check if the url ends in a / + // Check if the url ends in a / if (negotiateUrl.charAt(negotiateUrl.length() - 1) != '/') { negotiateUrl += "/"; } @@ -24,7 +24,11 @@ class Negotiate { // Add the query string back if it existed. if (queryStringIndex > 0) { - negotiateUrl += url.substring(url.indexOf('?')); + negotiateUrl += url.substring(queryStringIndex); + } + + if (!url.contains("negotiateVersion")) { + negotiateUrl = Utils.appendQueryString(negotiateUrl, "negotiateVersion=" + negotiateVersion); } return negotiateUrl; diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/NegotiateResponse.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/NegotiateResponse.java index f115e9601b..bf09b37578 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/NegotiateResponse.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/NegotiateResponse.java @@ -11,11 +11,13 @@ import com.google.gson.stream.JsonReader; class NegotiateResponse { private String connectionId; + private String connectionToken; private Set availableTransports = new HashSet<>(); private String redirectUrl; private String accessToken; private String error; private String finalUrl; + private int version; public NegotiateResponse(JsonReader reader) { try { @@ -30,6 +32,12 @@ class NegotiateResponse { case "ProtocolVersion": this.error = "Detected an ASP.NET SignalR Server. This client only supports connecting to an ASP.NET Core SignalR Server. See https://aka.ms/signalr-core-differences for details."; return; + case "negotiateVersion": + this.version = reader.nextInt(); + break; + case "connectionToken": + this.connectionToken = reader.nextString(); + break; case "url": this.redirectUrl = reader.nextString(); break; @@ -106,6 +114,14 @@ class NegotiateResponse { return finalUrl; } + public int getVersion() { + return version; + } + + public String getConnectionToken() { + return connectionToken; + } + public void setFinalUrl(String url) { this.finalUrl = url; } diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/Utils.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/Utils.java new file mode 100644 index 0000000000..d08c6fb914 --- /dev/null +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/Utils.java @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +package com.microsoft.signalr; + +class Utils { + public static String appendQueryString(String original, String queryStringValue) { + if (original.contains("?")) { + return original + "&" + queryStringValue; + } else { + return original + "?" + queryStringValue; + } + } +} \ No newline at end of file diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java index f28adf13e7..3f1d7d5b3f 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java @@ -1714,12 +1714,12 @@ class HubConnectionTest { List sentRequests = client.getSentRequests(); assertEquals(1, sentRequests.size()); - assertEquals("http://example.com/negotiate", sentRequests.get(0).getUrl()); + assertEquals("http://example.com/negotiate?negotiateVersion=1", sentRequests.get(0).getUrl()); } @Test public void negotiateThatRedirectsForeverFailsAfter100Tries() { - TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(200, "", "{\"url\":\"http://example.com\"}"))); HubConnection hubConnection = HubConnectionBuilder @@ -1752,7 +1752,7 @@ class HubConnectionTest { @Test public void connectionIdIsAvailableAfterStart() { - TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); @@ -1775,9 +1775,88 @@ class HubConnectionTest { assertNull(hubConnection.getConnectionId()); } + @Test + public void connectionTokenAppearsInQSConnectionIdIsOnConnectionInstance() { + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", + (req) -> Single.just(new HttpResponse(200, "", + "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + + "\"negotiateVersion\": 1," + + "\"connectionToken\":\"connection-token-value\"," + + "\"availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); + + MockTransport transport = new MockTransport(true); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withTransportImplementation(transport) + .withHttpClient(client) + .build(); + + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + assertNull(hubConnection.getConnectionId()); + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", hubConnection.getConnectionId()); + assertEquals("http://example.com?id=connection-token-value", transport.getUrl()); + hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + assertNull(hubConnection.getConnectionId()); + } + + @Test + public void connectionTokenIsIgnoredIfNegotiateVersionIsNotPresentInNegotiateResponse() { + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", + (req) -> Single.just(new HttpResponse(200, "", + "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + + "\"connectionToken\":\"connection-token-value\"," + + "\"availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); + + MockTransport transport = new MockTransport(true); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withTransportImplementation(transport) + .withHttpClient(client) + .build(); + + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + assertNull(hubConnection.getConnectionId()); + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", hubConnection.getConnectionId()); + assertEquals("http://example.com?id=bVOiRPG8-6YiJ6d7ZcTOVQ", transport.getUrl()); + hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + assertNull(hubConnection.getConnectionId()); + } + + @Test + public void negotiateVersionIsNotAddedIfAlreadyPresent() { + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=42", + (req) -> Single.just(new HttpResponse(200, "", + "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + + "\"connectionToken\":\"connection-token-value\"," + + "\"availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); + + MockTransport transport = new MockTransport(true); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com?negotiateVersion=42") + .withTransportImplementation(transport) + .withHttpClient(client) + .build(); + + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + assertNull(hubConnection.getConnectionId()); + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", hubConnection.getConnectionId()); + assertEquals("http://example.com?negotiateVersion=42&id=bVOiRPG8-6YiJ6d7ZcTOVQ", transport.getUrl()); + hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + assertNull(hubConnection.getConnectionId()); + } + @Test public void afterSuccessfulNegotiateConnectsWithWebsocketsTransport() { - TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); @@ -1798,7 +1877,7 @@ class HubConnectionTest { @Test public void afterSuccessfulNegotiateConnectsWithLongPollingTransport() { - TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); @@ -1891,7 +1970,7 @@ class HubConnectionTest { @Test public void receivingServerSentEventsTransportFromNegotiateFails() { - TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + "availableTransports\":[{\"transport\":\"ServerSentEvents\",\"transferFormats\":[\"Text\"]}]}"))); @@ -1911,7 +1990,7 @@ class HubConnectionTest { @Test public void negotiateThatReturnsErrorThrowsFromStart() { - TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(200, "", "{\"error\":\"Test error.\"}"))); MockTransport transport = new MockTransport(true); @@ -1928,7 +2007,7 @@ class HubConnectionTest { @Test public void DetectWhenTryingToConnectToClassicSignalRServer() { - TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(200, "", "{\"Url\":\"/signalr\"," + "\"ConnectionToken\":\"X97dw3uxW4NPPggQsYVcNcyQcuz4w2\"," + "\"ConnectionId\":\"05265228-1e2c-46c5-82a1-6a5bcc3f0143\"," + @@ -1954,9 +2033,9 @@ class HubConnectionTest { @Test public void negotiateRedirectIsFollowed() { - TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\"}"))) - .on("POST", "http://testexample.com/negotiate", + .on("POST", "http://testexample.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); @@ -1978,11 +2057,11 @@ class HubConnectionTest { AtomicReference beforeRedirectToken = new AtomicReference<>(); TestHttpClient client = new TestHttpClient() - .on("POST", "http://example.com/negotiate", (req) -> { + .on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> { beforeRedirectToken.set(req.getHeaders().get("Authorization")); return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"newToken\"}")); }) - .on("POST", "http://testexample.com/negotiate", (req) -> { + .on("POST", "http://testexample.com/negotiate?negotiateVersion=1", (req) -> { token.set(req.getHeaders().get("Authorization")); return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); @@ -2018,7 +2097,7 @@ class HubConnectionTest { public void accessTokenProviderIsUsedForNegotiate() { AtomicReference token = new AtomicReference<>(); TestHttpClient client = new TestHttpClient() - .on("POST", "http://example.com/negotiate", + .on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> { token.set(req.getHeaders().get("Authorization")); return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" @@ -2043,11 +2122,13 @@ class HubConnectionTest { public void accessTokenProviderIsOverriddenFromRedirectNegotiate() { AtomicReference token = new AtomicReference<>(); TestHttpClient client = new TestHttpClient() - .on("POST", "http://example.com/negotiate", (req) -> Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"newToken\"}"))) - .on("POST", "http://testexample.com/negotiate", (req) -> { + .on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"newToken\"}"))) + .on("POST", "http://testexample.com/negotiate?negotiateVersion=1", (req) -> { token.set(req.getHeaders().get("Authorization")); - return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" - + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); + return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + + "\"connectionToken\":\"connection-token-value\"," + + "\"negotiateVersion\":1," + + "\"availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); }); MockTransport transport = new MockTransport(true); @@ -2060,7 +2141,7 @@ class HubConnectionTest { hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); - assertEquals("http://testexample.com/?id=bVOiRPG8-6YiJ6d7ZcTOVQ", transport.getUrl()); + assertEquals("http://testexample.com/?id=connection-token-value", transport.getUrl()); hubConnection.stop(); assertEquals("Bearer newToken", token.get()); } @@ -2071,14 +2152,14 @@ class HubConnectionTest { AtomicReference beforeRedirectToken = new AtomicReference<>(); TestHttpClient client = new TestHttpClient() - .on("POST", "http://example.com/negotiate", (req) -> { + .on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> { beforeRedirectToken.set(req.getHeaders().get("Authorization")); return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"newToken\"}")); }) - .on("POST", "http://testexample.com/negotiate", (req) -> { + .on("POST", "http://testexample.com/negotiate?negotiateVersion=1", (req) -> { token.set(req.getHeaders().get("Authorization")); - return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" - + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); + return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + + "\"availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); }); MockTransport transport = new MockTransport(true); @@ -2112,7 +2193,7 @@ class HubConnectionTest { AtomicInteger redirectCount = new AtomicInteger(); TestHttpClient client = new TestHttpClient() - .on("POST", "http://example.com/negotiate", (req) -> { + .on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> { if (redirectCount.get() == 0) { redirectCount.incrementAndGet(); redirectToken.set(req.getHeaders().get("Authorization")); @@ -2122,7 +2203,7 @@ class HubConnectionTest { return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"secondRedirectToken\"}")); } }) - .on("POST", "http://testexample.com/negotiate", (req) -> { + .on("POST", "http://testexample.com/negotiate?negotiateVersion=1", (req) -> { token.set(req.getHeaders().get("Authorization")); return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); @@ -2189,7 +2270,7 @@ class HubConnectionTest { public void headersAreSetAndSentThroughBuilder() { AtomicReference header = new AtomicReference<>(); TestHttpClient client = new TestHttpClient() - .on("POST", "http://example.com/negotiate", + .on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> { header.set(req.getHeaders().get("ExampleHeader")); return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" @@ -2214,7 +2295,7 @@ class HubConnectionTest { public void headersAreNotClearedWhenConnectionIsRestarted() { AtomicReference header = new AtomicReference<>(); TestHttpClient client = new TestHttpClient() - .on("POST", "http://example.com/negotiate", + .on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> { header.set(req.getHeaders().get("Authorization")); return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" @@ -2244,12 +2325,12 @@ class HubConnectionTest { AtomicReference afterRedirectHeader = new AtomicReference<>(); TestHttpClient client = new TestHttpClient() - .on("POST", "http://example.com/negotiate", + .on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> { beforeRedirectHeader.set(req.getHeaders().get("Authorization")); return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"redirectToken\"}\"}")); }) - .on("POST", "http://testexample.com/negotiate", + .on("POST", "http://testexample.com/negotiate?negotiateVersion=1", (req) -> { afterRedirectHeader.set(req.getHeaders().get("Authorization")); return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" @@ -2287,7 +2368,7 @@ class HubConnectionTest { public void sameHeaderSetTwiceGetsOverwritten() { AtomicReference header = new AtomicReference<>(); TestHttpClient client = new TestHttpClient() - .on("POST", "http://example.com/negotiate", + .on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> { header.set(req.getHeaders().get("ExampleHeader")); return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" @@ -2332,8 +2413,8 @@ class HubConnectionTest { public void hubConnectionCanBeStartedAfterBeingStoppedAndRedirected() { MockTransport mockTransport = new MockTransport(); TestHttpClient client = new TestHttpClient() - .on("POST", "http://example.com/negotiate", (req) -> Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\"}"))) - .on("POST", "http://testexample.com/negotiate", (req) -> Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + .on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\"}"))) + .on("POST", "http://testexample.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); HubConnection hubConnection = HubConnectionBuilder @@ -2355,7 +2436,7 @@ class HubConnectionTest { @Test public void non200FromNegotiateThrowsError() { TestHttpClient client = new TestHttpClient() - .on("POST", "http://example.com/negotiate", + .on("POST", "http://example.com/negotiate?negotiateVersion=1", (req) -> Single.just(new HttpResponse(500, "Internal server error", ""))); MockTransport transport = new MockTransport(); diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/NegotiateResponseTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/NegotiateResponseTest.java index 88175d0ac9..1eaa0a00df 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/NegotiateResponseTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/NegotiateResponseTest.java @@ -15,8 +15,9 @@ import com.google.gson.stream.JsonReader; class NegotiateResponseTest { @Test public void VerifyNegotiateResponse() { - String stringNegotiateResponse = "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + - "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}," + + String stringNegotiateResponse = "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + + "\"negotiateVersion\": 99, \"connectionToken\":\"connection-token-value\"," + + "\"availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}," + "{\"transport\":\"ServerSentEvents\",\"transferFormats\":[\"Text\"]}," + "{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"; NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse))); @@ -26,6 +27,8 @@ class NegotiateResponseTest { assertNull(negotiateResponse.getAccessToken()); assertNull(negotiateResponse.getRedirectUrl()); assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", negotiateResponse.getConnectionId()); + assertEquals("connection-token-value", negotiateResponse.getConnectionToken()); + assertEquals(99, negotiateResponse.getVersion()); } @Test @@ -56,4 +59,23 @@ class NegotiateResponseTest { NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse))); assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", negotiateResponse.getConnectionId()); } + + @Test + public void NegotiateResponseWithNegotiateVersion() { + String stringNegotiateResponse = "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + + "\"negotiateVersion\": 99}"; + NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse))); + assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", negotiateResponse.getConnectionId()); + assertEquals(99, negotiateResponse.getVersion()); + } + + @Test + public void NegotiateResponseWithConnectionToken() { + String stringNegotiateResponse = "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + + "\"negotiateVersion\": 99, \"connectionToken\":\"connection-token-value\"}"; + NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse))); + assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", negotiateResponse.getConnectionId()); + assertEquals("connection-token-value", negotiateResponse.getConnectionToken()); + assertEquals(99, negotiateResponse.getVersion()); + } } diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/ResolveNegotiateUrlTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/ResolveNegotiateUrlTest.java index eacee30346..1c2a213f6e 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/ResolveNegotiateUrlTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/ResolveNegotiateUrlTest.java @@ -14,17 +14,18 @@ import org.junit.jupiter.params.provider.MethodSource; class ResolveNegotiateUrlTest { private static Stream protocols() { return Stream.of( - Arguments.of("http://example.com/hub/", "http://example.com/hub/negotiate"), - Arguments.of("http://example.com/hub", "http://example.com/hub/negotiate"), - Arguments.of("http://example.com/endpoint?q=my/Data", "http://example.com/endpoint/negotiate?q=my/Data"), - Arguments.of("http://example.com/endpoint/?q=my/Data", "http://example.com/endpoint/negotiate?q=my/Data"), - Arguments.of("http://example.com/endpoint/path/more?q=my/Data", "http://example.com/endpoint/path/more/negotiate?q=my/Data")); + Arguments.of("http://example.com/hub/", 0, "http://example.com/hub/negotiate?negotiateVersion=0"), + Arguments.of("http://example.com/hub", 1, "http://example.com/hub/negotiate?negotiateVersion=1"), + Arguments.of("http://example.com/endpoint?q=my/Data", 0, "http://example.com/endpoint/negotiate?q=my/Data&negotiateVersion=0"), + Arguments.of("http://example.com/endpoint/?q=my/Data", 1, "http://example.com/endpoint/negotiate?q=my/Data&negotiateVersion=1"), + Arguments.of("http://example.com/endpoint/path/more?q=my/Data", 0, "http://example.com/endpoint/path/more/negotiate?q=my/Data&negotiateVersion=0"), + Arguments.of("http://example.com/hub/?negotiateVersion=2", 0, "http://example.com/hub/negotiate?negotiateVersion=2")); } @ParameterizedTest @MethodSource("protocols") - public void checkNegotiateUrl(String url, String resolvedUrl) { - String urlResult = Negotiate.resolveNegotiateUrl(url); + public void checkNegotiateUrl(String url, int negotiateVersion, String resolvedUrl) { + String urlResult = Negotiate.resolveNegotiateUrl(url, negotiateVersion); assertEquals(resolvedUrl, urlResult); } } \ No newline at end of file diff --git a/src/SignalR/clients/ts/signalr/src/HttpConnection.ts b/src/SignalR/clients/ts/signalr/src/HttpConnection.ts index 5bd423124b..5b1a3744cd 100644 --- a/src/SignalR/clients/ts/signalr/src/HttpConnection.ts +++ b/src/SignalR/clients/ts/signalr/src/HttpConnection.ts @@ -23,6 +23,8 @@ const enum ConnectionState { /** @private */ export interface INegotiateResponse { connectionId?: string; + connectionToken?: string; + negotiateVersion?: number; availableTransports?: IAvailableTransport[]; url?: string; accessToken?: string; @@ -70,6 +72,8 @@ export class HttpConnection implements IConnection { public onreceive: ((data: string | ArrayBuffer) => void) | null; public onclose: ((e?: Error) => void) | null; + private readonly negotiateVersion: number = 1; + constructor(url: string, options: IHttpConnectionOptions = {}) { Arg.isRequired(url, "url"); @@ -272,8 +276,6 @@ export class HttpConnection implements IConnection { throw new Error("Negotiate redirection limit exceeded."); } - this.connectionId = negotiateResponse.connectionId; - await this.createTransport(url, this.options.transport, negotiateResponse, transferFormat); } @@ -322,32 +324,41 @@ export class HttpConnection implements IConnection { return Promise.reject(new Error(`Unexpected status code returned from negotiate ${response.statusCode}`)); } - return JSON.parse(response.content as string) as INegotiateResponse; + const negotiateResponse = JSON.parse(response.content as string) as INegotiateResponse; + if (!negotiateResponse.negotiateVersion || negotiateResponse.negotiateVersion < 1) { + // Negotiate version 0 doesn't use connectionToken + // So we set it equal to connectionId so all our logic can use connectionToken without being aware of the negotiate version + negotiateResponse.connectionToken = negotiateResponse.connectionId; + } + return negotiateResponse; } catch (e) { this.logger.log(LogLevel.Error, "Failed to complete negotiation with the server: " + e); return Promise.reject(e); } } - private createConnectUrl(url: string, connectionId: string | null | undefined) { - if (!connectionId) { + private createConnectUrl(url: string, connectionToken: string | null | undefined) { + if (!connectionToken) { return url; } - return url + (url.indexOf("?") === -1 ? "?" : "&") + `id=${connectionId}`; + + return url + (url.indexOf("?") === -1 ? "?" : "&") + `id=${connectionToken}`; } private async createTransport(url: string, requestedTransport: HttpTransportType | ITransport | undefined, negotiateResponse: INegotiateResponse, requestedTransferFormat: TransferFormat): Promise { - let connectUrl = this.createConnectUrl(url, negotiateResponse.connectionId); + let connectUrl = this.createConnectUrl(url, negotiateResponse.connectionToken); if (this.isITransport(requestedTransport)) { this.logger.log(LogLevel.Debug, "Connection was provided an instance of ITransport, using that directly."); this.transport = requestedTransport; await this.startTransport(connectUrl, requestedTransferFormat); + this.connectionId = negotiateResponse.connectionId; return; } const transportExceptions: any[] = []; const transports = negotiateResponse.availableTransports || []; + let negotiate: INegotiateResponse | undefined = negotiateResponse; for (const endpoint of transports) { const transportOrError = this.resolveTransportOrError(endpoint, requestedTransport, requestedTransferFormat); if (transportOrError instanceof Error) { @@ -355,20 +366,21 @@ export class HttpConnection implements IConnection { transportExceptions.push(`${endpoint.transport} failed: ${transportOrError}`); } else if (this.isITransport(transportOrError)) { this.transport = transportOrError; - if (!negotiateResponse.connectionId) { + if (!negotiate) { try { - negotiateResponse = await this.getNegotiationResponse(url); + negotiate = await this.getNegotiationResponse(url); } catch (ex) { return Promise.reject(ex); } - connectUrl = this.createConnectUrl(url, negotiateResponse.connectionId); + connectUrl = this.createConnectUrl(url, negotiate.connectionToken); } try { await this.startTransport(connectUrl, requestedTransferFormat); + this.connectionId = negotiate.connectionId; return; } catch (ex) { this.logger.log(LogLevel.Error, `Failed to start the transport '${endpoint.transport}': ${ex}`); - negotiateResponse.connectionId = undefined; + negotiate = undefined; transportExceptions.push(`${endpoint.transport} failed: ${ex}`); if (this.connectionState !== ConnectionState.Connecting) { @@ -504,7 +516,7 @@ export class HttpConnection implements IConnection { // Setting the url to the href propery of an anchor tag handles normalization // for us. There are 3 main cases. - // 1. Relative path normalization e.g "b" -> "http://localhost:5000/a/b" + // 1. Relative path normalization e.g "b" -> "http://localhost:5000/a/b" // 2. Absolute path normalization e.g "/a/b" -> "http://localhost:5000/a/b" // 3. Networkpath reference normalization e.g "//localhost:5000/a/b" -> "http://localhost:5000/a/b" const aTag = window.document.createElement("a"); @@ -522,6 +534,11 @@ export class HttpConnection implements IConnection { } negotiateUrl += "negotiate"; negotiateUrl += index === -1 ? "" : url.substring(index); + + if (negotiateUrl.indexOf("negotiateVersion") === -1) { + negotiateUrl += index === -1 ? "?" : "&"; + negotiateUrl += "negotiateVersion=" + this.negotiateVersion; + } return negotiateUrl; } } diff --git a/src/SignalR/clients/ts/signalr/tests/Common.ts b/src/SignalR/clients/ts/signalr/tests/Common.ts index 7a07a35d3d..6ce9e7779f 100644 --- a/src/SignalR/clients/ts/signalr/tests/Common.ts +++ b/src/SignalR/clients/ts/signalr/tests/Common.ts @@ -15,10 +15,10 @@ export function eachTransport(action: (transport: HttpTransportType) => void) { export function eachEndpointUrl(action: (givenUrl: string, expectedUrl: string) => void) { const urls = [ - [ "http://tempuri.org/endpoint/?q=my/Data", "http://tempuri.org/endpoint/negotiate?q=my/Data" ], - [ "http://tempuri.org/endpoint?q=my/Data", "http://tempuri.org/endpoint/negotiate?q=my/Data" ], - [ "http://tempuri.org/endpoint", "http://tempuri.org/endpoint/negotiate" ], - [ "http://tempuri.org/endpoint/", "http://tempuri.org/endpoint/negotiate" ], + [ "http://tempuri.org/endpoint/?q=my/Data", "http://tempuri.org/endpoint/negotiate?q=my/Data&negotiateVersion=1" ], + [ "http://tempuri.org/endpoint?q=my/Data", "http://tempuri.org/endpoint/negotiate?q=my/Data&negotiateVersion=1" ], + [ "http://tempuri.org/endpoint", "http://tempuri.org/endpoint/negotiate?negotiateVersion=1" ], + [ "http://tempuri.org/endpoint/", "http://tempuri.org/endpoint/negotiate?negotiateVersion=1" ], ]; urls.forEach((t) => action(t[0], t[1])); diff --git a/src/SignalR/clients/ts/signalr/tests/HttpConnection.test.ts b/src/SignalR/clients/ts/signalr/tests/HttpConnection.test.ts index d720011334..d010fbdbc6 100644 --- a/src/SignalR/clients/ts/signalr/tests/HttpConnection.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/HttpConnection.test.ts @@ -13,6 +13,7 @@ import { EventSourceConstructor, WebSocketConstructor } from "../src/Polyfills"; import { eachEndpointUrl, eachTransport, VerifyLogger } from "./Common"; import { TestHttpClient } from "./TestHttpClient"; import { TestTransport } from "./TestTransport"; +import { TestEvent, TestWebSocket } from "./TestWebSocket"; import { PromiseSource, registerUnhandledRejectionHandler, SyncPoint } from "./Utils"; const commonOptions: IHttpConnectionOptions = { @@ -20,6 +21,7 @@ const commonOptions: IHttpConnectionOptions = { }; const defaultConnectionId = "abc123"; +const defaultConnectionToken = "123abc"; const defaultNegotiateResponse: INegotiateResponse = { availableTransports: [ { transport: "WebSockets", transferFormats: ["Text", "Binary"] }, @@ -27,6 +29,8 @@ const defaultNegotiateResponse: INegotiateResponse = { { transport: "LongPolling", transferFormats: ["Text", "Binary"] }, ], connectionId: defaultConnectionId, + connectionToken: defaultConnectionToken, + negotiateVersion: 1, }; registerUnhandledRejectionHandler(); @@ -571,7 +575,7 @@ describe("HttpConnection", () => { let firstNegotiate = true; let firstPoll = true; const httpClient = new TestHttpClient() - .on("POST", /negotiate$/, () => { + .on("POST", /\/negotiate/, () => { if (firstNegotiate) { firstNegotiate = false; return { url: "https://another.domain.url/chat" }; @@ -602,8 +606,8 @@ describe("HttpConnection", () => { await connection.start(TransferFormat.Text); expect(httpClient.sentRequests.length).toBe(4); - expect(httpClient.sentRequests[0].url).toBe("http://tempuri.org/negotiate"); - expect(httpClient.sentRequests[1].url).toBe("https://another.domain.url/chat/negotiate"); + expect(httpClient.sentRequests[0].url).toBe("http://tempuri.org/negotiate?negotiateVersion=1"); + expect(httpClient.sentRequests[1].url).toBe("https://another.domain.url/chat/negotiate?negotiateVersion=1"); expect(httpClient.sentRequests[2].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i); expect(httpClient.sentRequests[3].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i); } finally { @@ -615,7 +619,7 @@ describe("HttpConnection", () => { it("fails to start if negotiate redirects more than 100 times", async () => { await VerifyLogger.run(async (logger) => { const httpClient = new TestHttpClient() - .on("POST", /negotiate$/, () => ({ url: "https://another.domain.url/chat" })); + .on("POST", /\/negotiate/, () => ({ url: "https://another.domain.url/chat" })); const options: IHttpConnectionOptions = { ...commonOptions, @@ -637,7 +641,7 @@ describe("HttpConnection", () => { let firstNegotiate = true; let firstPoll = true; const httpClient = new TestHttpClient() - .on("POST", /negotiate$/, (r) => { + .on("POST", /\/negotiate/, (r) => { if (firstNegotiate) { firstNegotiate = false; @@ -683,8 +687,8 @@ describe("HttpConnection", () => { await connection.start(TransferFormat.Text); expect(httpClient.sentRequests.length).toBe(4); - expect(httpClient.sentRequests[0].url).toBe("http://tempuri.org/negotiate"); - expect(httpClient.sentRequests[1].url).toBe("https://another.domain.url/chat/negotiate"); + expect(httpClient.sentRequests[0].url).toBe("http://tempuri.org/negotiate?negotiateVersion=1"); + expect(httpClient.sentRequests[1].url).toBe("https://another.domain.url/chat/negotiate?negotiateVersion=1"); expect(httpClient.sentRequests[2].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i); expect(httpClient.sentRequests[3].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i); } finally { @@ -696,7 +700,7 @@ describe("HttpConnection", () => { it("throws error if negotiate response has error", async () => { await VerifyLogger.run(async (logger) => { const httpClient = new TestHttpClient() - .on("POST", /negotiate$/, () => ({ error: "Negotiate error." })); + .on("POST", /\/negotiate/, () => ({ error: "Negotiate error." })); const options: IHttpConnectionOptions = { ...commonOptions, @@ -873,6 +877,253 @@ describe("HttpConnection", () => { }); }); + it("missing negotiateVersion ignores connectionToken", async () => { + await VerifyLogger.run(async (logger) => { + const availableTransport = { transport: "Custom", transferFormats: ["Text"] }; + const transport = { + connect(url: string, transferFormat: TransferFormat) { + return Promise.resolve(); + }, + send(data: any) { + return Promise.resolve(); + }, + stop() { + if (transport.onclose) { + transport.onclose(); + } + return Promise.resolve(); + }, + onclose: null, + onreceive: null, + } as ITransport; + const options: IHttpConnectionOptions = { + ...commonOptions, + httpClient: new TestHttpClient() + .on("POST", () => ({ connectionId: "42", connectionToken: "token", availableTransports: [availableTransport] })), + logger, + transport, + } as IHttpConnectionOptions; + + const connection = new HttpConnection("http://tempuri.org", options); + connection.onreceive = () => null; + try { + await connection.start(TransferFormat.Text); + expect(connection.connectionId).toBe("42"); + } finally { + await connection.stop(); + } + }); + }); + + it("negotiate version 0 ignores connectionToken", async () => { + await VerifyLogger.run(async (logger) => { + const availableTransport = { transport: "Custom", transferFormats: ["Text"] }; + const transport = { + connect(url: string, transferFormat: TransferFormat) { + return Promise.resolve(); + }, + send(data: any) { + return Promise.resolve(); + }, + stop() { + if (transport.onclose) { + transport.onclose(); + } + return Promise.resolve(); + }, + onclose: null, + onreceive: null, + } as ITransport; + const options: IHttpConnectionOptions = { + ...commonOptions, + httpClient: new TestHttpClient() + .on("POST", () => ({ connectionId: "42", connectionToken: "token", negotiateVersion: 0, availableTransports: [availableTransport] })), + logger, + transport, + } as IHttpConnectionOptions; + + const connection = new HttpConnection("http://tempuri.org", options); + connection.onreceive = () => null; + try { + await connection.start(TransferFormat.Text); + expect(connection.connectionId).toBe("42"); + } finally { + await connection.stop(); + } + }); + }); + + it("negotiate version 1 uses connectionToken for url and connectionId for property", async () => { + await VerifyLogger.run(async (logger) => { + const availableTransport = { transport: "Custom", transferFormats: ["Text"] }; + let connectUrl = ""; + const transport = { + connect(url: string, transferFormat: TransferFormat) { + connectUrl = url; + return Promise.resolve(); + }, + send(data: any) { + return Promise.resolve(); + }, + stop() { + if (transport.onclose) { + transport.onclose(); + } + return Promise.resolve(); + }, + onclose: null, + onreceive: null, + } as ITransport; + const options: IHttpConnectionOptions = { + ...commonOptions, + httpClient: new TestHttpClient() + .on("POST", () => ({ connectionId: "42", connectionToken: "token", negotiateVersion: 1, availableTransports: [availableTransport] })), + logger, + transport, + } as IHttpConnectionOptions; + + const connection = new HttpConnection("http://tempuri.org", options); + connection.onreceive = () => null; + try { + await connection.start(TransferFormat.Text); + expect(connection.connectionId).toBe("42"); + expect(connectUrl).toBe("http://tempuri.org?id=token"); + } finally { + await connection.stop(); + } + }); + }); + + it("negotiateVersion query string not added if already present", async () => { + await VerifyLogger.run(async (logger) => { + const connectUrl = new PromiseSource(); + const fakeTransport: ITransport = { + connect(url: string): Promise { + connectUrl.resolve(url); + return Promise.resolve(); + }, + send(): Promise { + return Promise.resolve(); + }, + stop(): Promise { + return Promise.resolve(); + }, + onclose: null, + onreceive: null, + }; + + const options: IHttpConnectionOptions = { + ...commonOptions, + httpClient: new TestHttpClient() + .on("POST", "http://tempuri.org/negotiate?negotiateVersion=42", () => "{ \"connectionId\": \"42\" }") + .on("GET", () => ""), + logger, + transport: fakeTransport, + } as IHttpConnectionOptions; + + const connection = new HttpConnection("http://tempuri.org?negotiateVersion=42", options); + try { + const startPromise = connection.start(TransferFormat.Text); + + expect(await connectUrl).toBe("http://tempuri.org?negotiateVersion=42&id=42"); + + await startPromise; + } finally { + (options.transport as ITransport).onclose!(); + await connection.stop(); + } + }); + }); + + it("negotiateVersion query string not added if already present after redirect", async () => { + await VerifyLogger.run(async (logger) => { + const connectUrl = new PromiseSource(); + const fakeTransport: ITransport = { + connect(url: string): Promise { + connectUrl.resolve(url); + return Promise.resolve(); + }, + send(): Promise { + return Promise.resolve(); + }, + stop(): Promise { + return Promise.resolve(); + }, + onclose: null, + onreceive: null, + }; + + const options: IHttpConnectionOptions = { + ...commonOptions, + httpClient: new TestHttpClient() + .on("POST", "http://tempuri.org/negotiate?negotiateVersion=1", () => "{ \"url\": \"http://redirect.org\" }") + .on("POST", "http://redirect.org/negotiate?negotiateVersion=1", () => "{ \"connectionId\": \"42\"}") + .on("GET", () => ""), + logger, + transport: fakeTransport, + } as IHttpConnectionOptions; + + const connection = new HttpConnection("http://tempuri.org", options); + try { + const startPromise = connection.start(TransferFormat.Text); + + expect(await connectUrl).toBe("http://redirect.org?id=42"); + + await startPromise; + } finally { + (options.transport as ITransport).onclose!(); + await connection.stop(); + } + }); + }); + + it("fallback changes connectionId property", async () => { + await VerifyLogger.run(async (logger) => { + const availableTransports = [{ transport: "WebSockets", transferFormats: ["Text"] }, { transport: "LongPolling", transferFormats: ["Text"] }]; + let negotiateCount: number = 0; + let getCount: number = 0; + let connection: HttpConnection; + let connectionId: string | undefined; + const options: IHttpConnectionOptions = { + WebSocket: TestWebSocket, + ...commonOptions, + httpClient: new TestHttpClient() + .on("POST", () => { + negotiateCount++; + return ({ connectionId: negotiateCount.toString(), connectionToken: "token", negotiateVersion: 1, availableTransports }); + }) + .on("GET", () => { + getCount++; + if (getCount === 1) { + return new HttpResponse(200); + } + connectionId = connection.connectionId; + return new HttpResponse(204); + }) + .on("DELETE", () => new HttpResponse(202)), + + logger, + } as IHttpConnectionOptions; + + TestWebSocket.webSocketSet = new PromiseSource(); + + connection = new HttpConnection("http://tempuri.org", options); + const startPromise = connection.start(TransferFormat.Text); + + await TestWebSocket.webSocketSet; + await TestWebSocket.webSocket.closeSet; + TestWebSocket.webSocket.onerror(new TestEvent()); + + try { + await startPromise; + } catch { } + + expect(negotiateCount).toEqual(2); + expect(connectionId).toEqual("2"); + }, + "Failed to start the transport 'WebSockets': Error: There was an error with the transport."); + }); + describe(".constructor", () => { it("throws if no Url is provided", async () => { // Force TypeScript to let us call the constructor incorrectly :) @@ -921,7 +1172,7 @@ describe("HttpConnection", () => { it("uses WebSocket constructor from options if provided", async () => { await VerifyLogger.run(async (logger) => { - class TestWebSocket { + class BadConstructorWebSocket { // The "_" prefix tell TypeScript not to worry about unused parameter, but tslint doesn't like it. // tslint:disable-next-line:variable-name constructor(_url: string, _protocols?: string | string[]) { @@ -931,7 +1182,7 @@ describe("HttpConnection", () => { const options: IHttpConnectionOptions = { ...commonOptions, - WebSocket: TestWebSocket as WebSocketConstructor, + WebSocket: BadConstructorWebSocket as WebSocketConstructor, logger, skipNegotiation: true, transport: HttpTransportType.WebSockets, diff --git a/src/SignalR/clients/ts/signalr/tests/HubConnectionBuilder.test.ts b/src/SignalR/clients/ts/signalr/tests/HubConnectionBuilder.test.ts index 4aaa70a0d8..e43f3a336c 100644 --- a/src/SignalR/clients/ts/signalr/tests/HubConnectionBuilder.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/HubConnectionBuilder.test.ts @@ -21,6 +21,8 @@ const longPollingNegotiateResponse = { { transport: "LongPolling", transferFormats: ["Text", "Binary"] }, ], connectionId: "abc123", + connectionToken: "123abc", + negotiateVersion: 1, }; const commonHttpOptions: IHttpConnectionOptions = { @@ -88,7 +90,7 @@ describe("HubConnectionBuilder", () => { const pollSent = new PromiseSource(); const pollCompleted = new PromiseSource(); const testClient = createTestClient(pollSent, pollCompleted.promise) - .on("POST", "http://example.com?id=abc123", (req) => { + .on("POST", "http://example.com?id=123abc", (req) => { // Respond from the poll with the handshake response pollCompleted.resolve(new HttpResponse(204, "No Content", "{}")); return new HttpResponse(202); @@ -104,7 +106,7 @@ describe("HubConnectionBuilder", () => { await expect(connection.start()).rejects.toThrow("The underlying connection was closed before the hub handshake could complete."); expect(connection.state).toBe(HubConnectionState.Disconnected); - expect((await pollSent.promise).url).toMatch(/http:\/\/example.com\?id=abc123.*/); + expect((await pollSent.promise).url).toMatch(/http:\/\/example.com\?id=123abc.*/); }); }); @@ -125,7 +127,7 @@ describe("HubConnectionBuilder", () => { const pollCompleted = new PromiseSource(); let negotiateRequest!: HttpRequest; const testClient = createTestClient(pollSent, pollCompleted.promise) - .on("POST", "http://example.com?id=abc123", (req) => { + .on("POST", "http://example.com?id=123abc", (req) => { // Respond from the poll with the handshake response negotiateRequest = req; pollCompleted.resolve(new HttpResponse(204, "No Content", "{}")); @@ -219,7 +221,7 @@ describe("HubConnectionBuilder", () => { const pollSent = new PromiseSource(); const pollCompleted = new PromiseSource(); const testClient = createTestClient(pollSent, pollCompleted.promise) - .on("POST", "http://example.com?id=abc123", (req) => { + .on("POST", "http://example.com?id=123abc", (req) => { // Respond from the poll with the handshake response pollCompleted.resolve(new HttpResponse(204, "No Content", "{}")); return new HttpResponse(202); @@ -244,7 +246,7 @@ describe("HubConnectionBuilder", () => { const pollSent = new PromiseSource(); const pollCompleted = new PromiseSource(); const testClient = createTestClient(pollSent, pollCompleted.promise) - .on("POST", "http://example.com?id=abc123", (req) => { + .on("POST", "http://example.com?id=123abc", (req) => { // Respond from the poll with the handshake response pollCompleted.resolve(new HttpResponse(204, "No Content", "{}")); return new HttpResponse(202); @@ -274,7 +276,7 @@ describe("HubConnectionBuilder", () => { const pollSent = new PromiseSource(); const pollCompleted = new PromiseSource(); const testClient = createTestClient(pollSent, pollCompleted.promise) - .on("POST", "http://example.com?id=abc123", (req) => { + .on("POST", "http://example.com?id=123abc", (req) => { // Respond from the poll with the handshake response pollCompleted.resolve(new HttpResponse(204, "No Content", "{}")); return new HttpResponse(202); @@ -413,8 +415,8 @@ function createConnectionBuilder(logger?: ILogger): HubConnectionBuilder { function createTestClient(pollSent: PromiseSource, pollCompleted: Promise, negotiateResponse?: any): TestHttpClient { let firstRequest = true; return new TestHttpClient() - .on("POST", "http://example.com/negotiate", () => negotiateResponse || longPollingNegotiateResponse) - .on("GET", /http:\/\/example.com\?id=abc123&_=.*/, (req) => { + .on("POST", "http://example.com/negotiate?negotiateVersion=1", () => negotiateResponse || longPollingNegotiateResponse) + .on("GET", /http:\/\/example.com\?id=123abc&_=.*/, (req) => { if (firstRequest) { firstRequest = false; return new HttpResponse(200); diff --git a/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netcoreapp.cs b/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netcoreapp.cs index 27d206da33..f557b74f58 100644 --- a/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netcoreapp.cs +++ b/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netcoreapp.cs @@ -34,7 +34,9 @@ namespace Microsoft.AspNetCore.Http.Connections public string AccessToken { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.Collections.Generic.IList AvailableTransports { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string ConnectionId { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public string ConnectionToken { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string Error { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string Url { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public int Version { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } } } diff --git a/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netstandard2.0.cs b/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netstandard2.0.cs index 27d206da33..f557b74f58 100644 --- a/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netstandard2.0.cs +++ b/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netstandard2.0.cs @@ -34,7 +34,9 @@ namespace Microsoft.AspNetCore.Http.Connections public string AccessToken { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.Collections.Generic.IList AvailableTransports { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string ConnectionId { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public string ConnectionToken { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string Error { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string Url { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public int Version { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } } } diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs index a23a9d6c0b..a98e0ba94c 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs @@ -15,6 +15,8 @@ namespace Microsoft.AspNetCore.Http.Connections { private const string ConnectionIdPropertyName = "connectionId"; private static JsonEncodedText ConnectionIdPropertyNameBytes = JsonEncodedText.Encode(ConnectionIdPropertyName); + private const string ConnectionTokenPropertyName = "connectionToken"; + private static JsonEncodedText ConnectionTokenPropertyNameBytes = JsonEncodedText.Encode(ConnectionTokenPropertyName); private const string UrlPropertyName = "url"; private static JsonEncodedText UrlPropertyNameBytes = JsonEncodedText.Encode(UrlPropertyName); private const string AccessTokenPropertyName = "accessToken"; @@ -27,6 +29,8 @@ namespace Microsoft.AspNetCore.Http.Connections private static JsonEncodedText TransferFormatsPropertyNameBytes = JsonEncodedText.Encode(TransferFormatsPropertyName); private const string ErrorPropertyName = "error"; private static JsonEncodedText ErrorPropertyNameBytes = JsonEncodedText.Encode(ErrorPropertyName); + private const string NegotiateVersionPropertyName = "negotiateVersion"; + private static JsonEncodedText NegotiateVersionPropertyNameBytes = JsonEncodedText.Encode(NegotiateVersionPropertyName); // Use C#7.3's ReadOnlySpan optimization for static data https://vcsjones.com/2019/02/01/csharp-readonly-span-bytes-static/ // Used to detect ASP.NET SignalR Server connection attempt @@ -41,6 +45,19 @@ namespace Microsoft.AspNetCore.Http.Connections var writer = reusableWriter.GetJsonWriter(); writer.WriteStartObject(); + // If we already have an error its due to a protocol version incompatibility. + // We can just write the error and complete the JSON object and return. + if (!string.IsNullOrEmpty(response.Error)) + { + writer.WriteString(ErrorPropertyNameBytes, response.Error); + writer.WriteEndObject(); + writer.Flush(); + Debug.Assert(writer.CurrentDepth == 0); + return; + } + + writer.WriteNumber(NegotiateVersionPropertyNameBytes, response.Version); + if (!string.IsNullOrEmpty(response.Url)) { writer.WriteString(UrlPropertyNameBytes, response.Url); @@ -56,6 +73,11 @@ namespace Microsoft.AspNetCore.Http.Connections writer.WriteString(ConnectionIdPropertyNameBytes, response.ConnectionId); } + if (response.Version > 0 && !string.IsNullOrEmpty(response.ConnectionToken)) + { + writer.WriteString(ConnectionTokenPropertyNameBytes, response.ConnectionToken); + } + writer.WriteStartArray(AvailableTransportsPropertyNameBytes); if (response.AvailableTransports != null) @@ -112,10 +134,12 @@ namespace Microsoft.AspNetCore.Http.Connections reader.EnsureObjectStart(); string connectionId = null; + string connectionToken = null; string url = null; string accessToken = null; List availableTransports = null; string error = null; + int version = 0; var completed = false; while (!completed && reader.CheckRead()) @@ -135,6 +159,14 @@ namespace Microsoft.AspNetCore.Http.Connections { connectionId = reader.ReadAsString(ConnectionIdPropertyName); } + else if (reader.ValueTextEquals(ConnectionTokenPropertyNameBytes.EncodedUtf8Bytes)) + { + connectionToken = reader.ReadAsString(ConnectionTokenPropertyName); + } + else if (reader.ValueTextEquals(NegotiateVersionPropertyNameBytes.EncodedUtf8Bytes)) + { + version = reader.ReadAsInt32(NegotiateVersionPropertyName).GetValueOrDefault(); + } else if (reader.ValueTextEquals(AvailableTransportsPropertyNameBytes.EncodedUtf8Bytes)) { reader.CheckRead(); @@ -182,6 +214,14 @@ namespace Microsoft.AspNetCore.Http.Connections throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'."); } + if (version > 0) + { + if (connectionToken == null) + { + throw new InvalidDataException($"Missing required property '{ConnectionTokenPropertyName}'."); + } + } + if (availableTransports == null) { throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'."); @@ -191,10 +231,12 @@ namespace Microsoft.AspNetCore.Http.Connections return new NegotiationResponse { ConnectionId = connectionId, + ConnectionToken = connectionToken, Url = url, AccessToken = accessToken, AvailableTransports = availableTransports, Error = error, + Version = version }; } catch (Exception ex) diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs index a01d2e637c..69810a5a71 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs @@ -10,6 +10,8 @@ namespace Microsoft.AspNetCore.Http.Connections public string Url { get; set; } public string AccessToken { get; set; } public string ConnectionId { get; set; } + public string ConnectionToken { get; set; } + public int Version { get; set; } public IList AvailableTransports { get; set; } public string Error { get; set; } } diff --git a/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp.cs b/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp.cs index 7810a4985d..5ee369727c 100644 --- a/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp.cs +++ b/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp.cs @@ -53,6 +53,7 @@ namespace Microsoft.AspNetCore.Http.Connections public long ApplicationMaxBufferSize { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.Collections.Generic.IList AuthorizationData { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } public Microsoft.AspNetCore.Http.Connections.LongPollingOptions LongPolling { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } + public int MinimumProtocolVersion { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public long TransportMaxBufferSize { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public Microsoft.AspNetCore.Http.Connections.HttpTransportType Transports { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public Microsoft.AspNetCore.Http.Connections.WebSocketOptions WebSockets { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } diff --git a/src/SignalR/common/Http.Connections/src/HttpConnectionDispatcherOptions.cs b/src/SignalR/common/Http.Connections/src/HttpConnectionDispatcherOptions.cs index eff4ae76e4..e1f97d7183 100644 --- a/src/SignalR/common/Http.Connections/src/HttpConnectionDispatcherOptions.cs +++ b/src/SignalR/common/Http.Connections/src/HttpConnectionDispatcherOptions.cs @@ -57,5 +57,11 @@ namespace Microsoft.AspNetCore.Http.Connections /// Gets or sets the maximum buffer size of the application writer. /// public long ApplicationMaxBufferSize { get; set; } + + /// + /// Gets or sets the minimum protocol verison supported by the server. + /// The default value is 0, the lowest possible protocol version. + /// + public int MinimumProtocolVersion { get; set; } = 0; } } diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index 6e21d7a665..6d3fe467e9 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -48,11 +48,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal /// Creates the DefaultConnectionContext without Pipes to avoid upfront allocations. /// The caller is expected to set the and pipes manually. /// - /// + /// + /// /// - public HttpConnectionContext(string id, ILogger logger) + public HttpConnectionContext(string connectionId, string connectionToken, ILogger logger) { - ConnectionId = id; + ConnectionId = connectionId; + ConnectionToken = connectionToken; LastSeenUtc = DateTime.UtcNow; // The default behavior is that both formats are supported. @@ -74,8 +76,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal Features.Set(this); } - public HttpConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application, ILogger logger = null) - : this(id, logger) + internal HttpConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application, ILogger logger = null) + : this(id, null, logger) { Transport = transport; Application = application; @@ -113,6 +115,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal public override string ConnectionId { get; set; } + internal string ConnectionToken { get; set; } + public override IFeatureCollection Features { get; } public ClaimsPrincipal User { get; set; } diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.Log.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.Log.cs index af91f08af2..80f3d32800 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.Log.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.Log.cs @@ -52,6 +52,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private static readonly Action _failedToReadHttpRequestBody = LoggerMessage.Define(LogLevel.Debug, new EventId(14, "FailedToReadHttpRequestBody"), "Connection {TransportConnectionId} failed to read the HTTP request body."); + private static readonly Action _negotiateProtocolVersionMismatch = + LoggerMessage.Define(LogLevel.Debug, new EventId(15, "NegotiateProtocolVersionMismatch"), "The client requested version '{clientProtocolVersion}', but the server does not support this version."); + + private static readonly Action _invalidNegotiateProtocolVersion = + LoggerMessage.Define(LogLevel.Debug, new EventId(16, "InvalidNegotiateProtocolVersion"), "The client requested an invalid protocol version '{queryStringVersionValue}'"); + public static void ConnectionDisposed(ILogger logger, string connectionId) { _connectionDisposed(logger, connectionId, null); @@ -121,6 +127,16 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal { _failedToReadHttpRequestBody(logger, connectionId, ex); } + + public static void NegotiateProtocolVersionMismatch(ILogger logger, int clientProtocolVersion) + { + _negotiateProtocolVersionMismatch(logger, clientProtocolVersion, null); + } + + public static void InvalidNegotiateProtocolVersion(ILogger logger, string requestedProtocolVersion) + { + _invalidNegotiateProtocolVersion(logger, requestedProtocolVersion, null); + } } } } diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index bf82562a7b..6ed9932839 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -45,6 +45,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private readonly HttpConnectionManager _manager; private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; + private static readonly int _protocolVersion = 1; public HttpConnectionDispatcher(HttpConnectionManager manager, ILoggerFactory loggerFactory) { @@ -58,7 +59,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal // Create the log scope and attempt to pass the Connection ID to it so as many logs as possible contain // the Connection ID metadata. If this is the negotiate request then the Connection ID for the scope will // be set a little later. - var logScope = new ConnectionLogScope(GetConnectionId(context)); + + HttpConnectionContext connectionContext = null; + var connectionToken = GetConnectionToken(context); + if (connectionToken != null) + { + _manager.TryGetConnection(connectionToken, out connectionContext); + } + + var logScope = new ConnectionLogScope(connectionContext?.ConnectionId); using (_logger.BeginScope(logScope)) { if (HttpMethods.IsPost(context.Request.Method)) @@ -278,13 +287,44 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatcherOptions options, ConnectionLogScope logScope) { context.Response.ContentType = "application/json"; + string error = null; + int clientProtocolVersion = 0; + if (context.Request.Query.TryGetValue("NegotiateVersion", out var queryStringVersion)) + { + // Set the negotiate response to the protocol we use. + var queryStringVersionValue = queryStringVersion.ToString(); + if (!int.TryParse(queryStringVersionValue, out clientProtocolVersion)) + { + error = $"The client requested an invalid protocol version '{queryStringVersionValue}'"; + Log.InvalidNegotiateProtocolVersion(_logger, queryStringVersionValue); + } + else if (clientProtocolVersion < options.MinimumProtocolVersion) + { + error = $"The client requested version '{clientProtocolVersion}', but the server does not support this version."; + Log.NegotiateProtocolVersionMismatch(_logger, clientProtocolVersion); + } + else if (clientProtocolVersion > _protocolVersion) + { + clientProtocolVersion = _protocolVersion; + } + } + else if (options.MinimumProtocolVersion > 0) + { + // NegotiateVersion wasn't parsed meaning the client requests version 0. + error = $"The client requested version '0', but the server does not support this version."; + Log.NegotiateProtocolVersionMismatch(_logger, 0); + } // Establish the connection - var connection = CreateConnection(options); + HttpConnectionContext connection = null; + if (error == null) + { + connection = CreateConnection(options, clientProtocolVersion); + } // Set the Connection ID on the logging scope so that logs from now on will have the // Connection ID metadata set. - logScope.ConnectionId = connection.ConnectionId; + logScope.ConnectionId = connection?.ConnectionId; // Don't use thread static instance here because writer is used with async var writer = new MemoryBufferWriter(); @@ -292,7 +332,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal try { // Get the bytes for the connection id - WriteNegotiatePayload(writer, connection.ConnectionId, context, options); + WriteNegotiatePayload(writer, connection?.ConnectionId, connection?.ConnectionToken, context, options, clientProtocolVersion, error); Log.NegotiationRequest(_logger); @@ -306,10 +346,21 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal } } - private static void WriteNegotiatePayload(IBufferWriter writer, string connectionId, HttpContext context, HttpConnectionDispatcherOptions options) + private void WriteNegotiatePayload(IBufferWriter writer, string connectionId, string connectionToken, HttpContext context, HttpConnectionDispatcherOptions options, + int clientProtocolVersion, string error) { var response = new NegotiationResponse(); + + if (!string.IsNullOrEmpty(error)) + { + response.Error = error; + NegotiateProtocol.WriteResponse(response, writer); + return; + } + + response.Version = clientProtocolVersion; response.ConnectionId = connectionId; + response.ConnectionToken = connectionToken; response.AvailableTransports = new List(); if ((options.Transports & HttpTransportType.WebSockets) != 0 && ServerHasWebSockets(context.Features)) @@ -335,7 +386,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal return features.Get() != null; } - private static string GetConnectionId(HttpContext context) => context.Request.Query["id"]; + private static string GetConnectionToken(HttpContext context) => context.Request.Query["id"]; private async Task ProcessSend(HttpContext context, HttpConnectionDispatcherOptions options) { @@ -608,9 +659,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private async Task GetConnectionAsync(HttpContext context) { - var connectionId = GetConnectionId(context); + var connectionToken = GetConnectionToken(context); - if (StringValues.IsNullOrEmpty(connectionId)) + if (StringValues.IsNullOrEmpty(connectionToken)) { // There's no connection ID: bad request context.Response.StatusCode = StatusCodes.Status400BadRequest; @@ -619,7 +670,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal return null; } - if (!_manager.TryGetConnection(connectionId, out var connection)) + if (!_manager.TryGetConnection(connectionToken, out var connection)) { // No connection with that ID: Not Found context.Response.StatusCode = StatusCodes.Status404NotFound; @@ -634,15 +685,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal // This is only used for WebSockets connections, which can connect directly without negotiating private async Task GetOrCreateConnectionAsync(HttpContext context, HttpConnectionDispatcherOptions options) { - var connectionId = GetConnectionId(context); + var connectionToken = GetConnectionToken(context); HttpConnectionContext connection; // There's no connection id so this is a brand new connection - if (StringValues.IsNullOrEmpty(connectionId)) + if (StringValues.IsNullOrEmpty(connectionToken)) { connection = CreateConnection(options); } - else if (!_manager.TryGetConnection(connectionId, out connection)) + else if (!_manager.TryGetConnection(connectionToken, out connection)) { // No connection with that ID: Not Found context.Response.StatusCode = StatusCodes.Status404NotFound; @@ -653,12 +704,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal return connection; } - private HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options) + private HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, int clientProtocolVersion = 0) { var transportPipeOptions = new PipeOptions(pauseWriterThreshold: options.TransportMaxBufferSize, resumeWriterThreshold: options.TransportMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); var appPipeOptions = new PipeOptions(pauseWriterThreshold: options.ApplicationMaxBufferSize, resumeWriterThreshold: options.ApplicationMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); - - return _manager.CreateConnection(transportPipeOptions, appPipeOptions); + return _manager.CreateConnection(transportPipeOptions, appPipeOptions, clientProtocolVersion); } private class EmptyServiceProvider : IServiceProvider diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index dda35866a4..4a97681fc0 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -78,18 +78,28 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal /// Creates a connection without Pipes setup to allow saving allocations until Pipes are needed. /// /// - internal HttpConnectionContext CreateConnection(PipeOptions transportPipeOptions, PipeOptions appPipeOptions) + internal HttpConnectionContext CreateConnection(PipeOptions transportPipeOptions, PipeOptions appPipeOptions, int negotiateVersion = 0) { + string connectionToken; var id = MakeNewConnectionId(); + if (negotiateVersion > 0) + { + connectionToken = MakeNewConnectionId(); + } + else + { + connectionToken = id; + } Log.CreatedNewConnection(_logger, id); var connectionTimer = HttpConnectionsEventSource.Log.ConnectionStart(id); - var connection = new HttpConnectionContext(id, _connectionLogger); + var connection = new HttpConnectionContext(id, connectionToken, _connectionLogger); var pair = DuplexPipe.CreateConnectionPair(transportPipeOptions, appPipeOptions); connection.Transport = pair.Application; connection.Application = pair.Transport; - _connections.TryAdd(id, (connection, connectionTimer)); + _connections.TryAdd(connectionToken, (connection, connectionTimer)); + return connection; } @@ -205,7 +215,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal { // Remove it from the list after disposal so that's it's easy to see // connections that might be in a hung state via the connections list - RemoveConnection(connection.ConnectionId); + RemoveConnection(connection.ConnectionToken); } } } diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index 7b164b8929..6e28f47cc6 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -37,7 +37,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests public class HttpConnectionDispatcherTests : VerifiableLoggedTest { [Fact] - public async Task NegotiateReservesConnectionIdAndReturnsIt() + public async Task NegotiateVersionZeroReservesConnectionIdAndReturnsIt() { using (StartVerifiableLog()) { @@ -54,8 +54,35 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions()); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); var connectionId = negotiateResponse.Value("connectionId"); - Assert.True(manager.TryGetConnection(connectionId, out var connectionContext)); + var connectionToken = negotiateResponse.Value("connectionToken"); + Assert.Null(connectionToken); + Assert.NotNull(connectionId); + } + } + + [Fact] + public async Task NegotiateReservesConnectionTokenAndConnectionIdAndReturnsIt() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddOptions(); + var ms = new MemoryStream(); + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + context.Response.Body = ms; + context.Request.QueryString = new QueryString("?negotiateVersion=1"); + await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions()); + var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); + var connectionId = negotiateResponse.Value("connectionId"); + var connectionToken = negotiateResponse.Value("connectionToken"); + Assert.True(manager.TryGetConnection(connectionToken, out var connectionContext)); Assert.Equal(connectionId, connectionContext.ConnectionId); + Assert.NotEqual(connectionId, connectionToken); } } @@ -74,12 +101,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; context.Response.Body = ms; + context.Request.QueryString = new QueryString("?negotiateVersion=1"); var options = new HttpConnectionDispatcherOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4 }; await dispatcher.ExecuteNegotiateAsync(context, options); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); - var connectionId = negotiateResponse.Value("connectionId"); - context.Request.QueryString = context.Request.QueryString.Add("id", connectionId); - Assert.True(manager.TryGetConnection(connectionId, out var connection)); + var connectionToken = negotiateResponse.Value("connectionToken"); + context.Request.QueryString = context.Request.QueryString.Add("id", connectionToken); + Assert.True(manager.TryGetConnection(connectionToken, out var connection)); // Fake actual connection after negotiate to populate the pipes on the connection await dispatcher.ExecuteAsync(context, options, c => Task.CompletedTask); @@ -95,6 +123,62 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } + [Fact] + public async Task InvalidNegotiateProtocolVersionThrows() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddOptions(); + var ms = new MemoryStream(); + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + context.Response.Body = ms; + context.Request.QueryString = new QueryString("?negotiateVersion=Invalid"); + var options = new HttpConnectionDispatcherOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4 }; + await dispatcher.ExecuteNegotiateAsync(context, options); + var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); + + var error = negotiateResponse.Value("error"); + Assert.Equal("The client requested an invalid protocol version 'Invalid'", error); + + var connectionId = negotiateResponse.Value("connectionId"); + Assert.Null(connectionId); + } + } + + [Fact] + public async Task NoNegotiateVersionInQueryStringThrowsWhenMinProtocolVersionIsSet() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddOptions(); + var ms = new MemoryStream(); + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + context.Response.Body = ms; + context.Request.QueryString = new QueryString(""); + var options = new HttpConnectionDispatcherOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4, MinimumProtocolVersion = 1 }; + await dispatcher.ExecuteNegotiateAsync(context, options); + var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); + + var error = negotiateResponse.Value("error"); + Assert.Equal("The client requested version '0', but the server does not support this version.", error); + + var connectionId = negotiateResponse.Value("connectionId"); + Assert.Null(connectionId); + } + } + [Theory] [InlineData(HttpTransportType.LongPolling)] [InlineData(HttpTransportType.ServerSentEvents)] @@ -125,7 +209,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -166,6 +251,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; context.Response.Body = ms; + context.Request.QueryString = new QueryString("?negotiateVersion=1"); await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions { Transports = transports }); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); @@ -204,6 +290,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Method = "GET"; var values = new Dictionary(); values["id"] = "unknown"; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; SetTransport(context, transportType); @@ -240,6 +327,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Method = "POST"; var values = new Dictionary(); values["id"] = "unknown"; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -276,7 +364,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -315,6 +404,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Method = "POST"; var values = new Dictionary(); values["id"] = connection.ConnectionId; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -354,7 +444,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -415,7 +506,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -481,7 +573,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -544,6 +637,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Method = "POST"; var values = new Dictionary(); values["id"] = connection.ConnectionId; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -613,7 +707,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; values["another"] = "value"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -661,8 +756,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var connectionHttpContext = connection.GetHttpContext(); Assert.NotNull(connectionHttpContext); - Assert.Equal(2, connectionHttpContext.Request.Query.Count); - Assert.Equal(connection.ConnectionId, connectionHttpContext.Request.Query["id"]); + Assert.Equal(3, connectionHttpContext.Request.Query.Count); Assert.Equal("value", connectionHttpContext.Request.Query["another"]); Assert.Equal(3, connectionHttpContext.Request.Headers.Count); @@ -706,6 +800,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests services.AddSingleton(); context.Request.Path = "/foo"; context.Request.Method = "GET"; + context.Request.QueryString = new QueryString("?negotiateVersion=1"); SetTransport(context, transportType); @@ -748,7 +843,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -775,6 +871,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests services.AddSingleton(); context.Request.Path = "/foo"; context.Request.Method = "POST"; + context.Request.QueryString = new QueryString("?negotiateVersion=1"); var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); @@ -846,6 +943,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); var context = MakeRequest("/foo", connection); + SetTransport(context, HttpTransportType.ServerSentEvents); var services = new ServiceCollection(); @@ -857,7 +955,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - var exists = manager.TryGetConnection(connection.ConnectionId, out _); + var exists = manager.TryGetConnection(connection.ConnectionToken, out _); Assert.False(exists); } } @@ -1221,7 +1319,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests await task; Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - var exists = manager.TryGetConnection(connection.ConnectionId, out _); + var exists = manager.TryGetConnection(connection.ConnectionToken, out _); Assert.False(exists); } } @@ -1262,7 +1360,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests await task; Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); - var exists = manager.TryGetConnection(connection.ConnectionId, out _); + var exists = manager.TryGetConnection(connection.ConnectionToken, out _); Assert.False(exists); } } @@ -1364,10 +1462,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Method = "GET"; context.RequestServices = sp; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; - var builder = new ConnectionBuilder(sp); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1452,7 +1550,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests // Issue the delete request var deleteContext = new DefaultHttpContext(); deleteContext.Request.Path = "/foo"; - deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); + deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionToken}"); deleteContext.Request.Method = "DELETE"; var ms = new MemoryStream(); deleteContext.Response.Body = ms; @@ -1495,7 +1593,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests // Issue the delete request and make sure the poll completes var deleteContext = new DefaultHttpContext(); deleteContext.Request.Path = "/foo"; - deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); + deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionToken}"); deleteContext.Request.Method = "DELETE"; Assert.False(pollTask.IsCompleted); @@ -1513,7 +1611,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.Equal("text/plain", deleteContext.Response.ContentType); // Verify the connection was removed from the manager - Assert.False(manager.TryGetConnection(connection.ConnectionId, out _)); + Assert.False(manager.TryGetConnection(connection.ConnectionToken, out _)); } } @@ -1543,7 +1641,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests // Issue the delete request and make sure the poll completes var deleteContext = new DefaultHttpContext(); deleteContext.Request.Path = "/foo"; - deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); + deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionToken}"); deleteContext.Request.Method = "DELETE"; await dispatcher.ExecuteAsync(deleteContext, options, app).OrTimeout(); @@ -1561,7 +1659,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests await connection.DisposeAndRemoveTask.OrTimeout(); // Verify the connection was removed from the manager - Assert.False(manager.TryGetConnection(connection.ConnectionId, out _)); + Assert.False(manager.TryGetConnection(connection.ConnectionToken, out _)); } } @@ -1581,6 +1679,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; context.Response.Body = ms; + context.Request.QueryString = new QueryString("?negotiateVersion=1"); await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions { Transports = HttpTransportType.WebSockets }); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); @@ -1637,7 +1736,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; var buffer = Encoding.UTF8.GetBytes("Hello, world"); @@ -1693,7 +1793,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; var buffer = Encoding.UTF8.GetBytes("Hello, world"); @@ -1746,7 +1847,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; var buffer = Encoding.UTF8.GetBytes("Hello, world"); @@ -1808,7 +1910,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests await pollTask.OrTimeout(); Assert.Equal(StatusCodes.Status500InternalServerError, pollContext.Response.StatusCode); - Assert.False(manager.TryGetConnection(connection.ConnectionId, out var _)); + Assert.False(manager.TryGetConnection(connection.ConnectionToken, out var _)); } } @@ -1831,7 +1933,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -1853,14 +1956,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } - private static DefaultHttpContext MakeRequest(string path, ConnectionContext connection, string format = null) + private static DefaultHttpContext MakeRequest(string path, HttpConnectionContext connection, string format = null) { var context = new DefaultHttpContext(); context.Features.Set(new ResponseFeature()); context.Request.Path = path; context.Request.Method = "GET"; var values = new Dictionary(); - values["id"] = connection.ConnectionId; + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; if (format != null) { values["format"] = format; diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs index 5c30a490f7..ade605b08a 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs @@ -131,7 +131,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.NotNull(connection.ConnectionId); - Assert.True(connectionManager.TryGetConnection(connection.ConnectionId, out var newConnection)); + Assert.True(connectionManager.TryGetConnection(connection.ConnectionToken, out var newConnection)); Assert.Same(newConnection, connection); } } @@ -143,13 +143,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var connectionManager = CreateConnectionManager(LoggerFactory); var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default); - var transport = connection.Transport; Assert.NotNull(connection.ConnectionId); + Assert.NotNull(connection.ConnectionToken); Assert.NotNull(transport); - Assert.True(connectionManager.TryGetConnection(connection.ConnectionId, out var newConnection)); + Assert.True(connectionManager.TryGetConnection(connection.ConnectionToken, out var newConnection)); Assert.Same(newConnection, connection); Assert.Same(transport, newConnection.Transport); } @@ -168,12 +168,55 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.NotNull(connection.ConnectionId); Assert.NotNull(transport); - Assert.True(connectionManager.TryGetConnection(connection.ConnectionId, out var newConnection)); + Assert.True(connectionManager.TryGetConnection(connection.ConnectionToken, out var newConnection)); Assert.Same(newConnection, connection); Assert.Same(transport, newConnection.Transport); - connectionManager.RemoveConnection(connection.ConnectionId); - Assert.False(connectionManager.TryGetConnection(connection.ConnectionId, out newConnection)); + connectionManager.RemoveConnection(connection.ConnectionToken); + Assert.False(connectionManager.TryGetConnection(connection.ConnectionToken, out newConnection)); + } + } + + [Fact] + public void ConnectionIdAndConnectionTokenAreTheSameForNegotiateVersionZero() + { + using (StartVerifiableLog()) + { + var connectionManager = CreateConnectionManager(LoggerFactory); + var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default, negotiateVersion: 0); + + var transport = connection.Transport; + + Assert.NotNull(connection.ConnectionId); + Assert.NotNull(transport); + + Assert.True(connectionManager.TryGetConnection(connection.ConnectionToken, out var newConnection)); + Assert.Same(newConnection, connection); + Assert.Same(transport, newConnection.Transport); + Assert.Equal(connection.ConnectionId, connection.ConnectionToken); + + } + } + + [Fact] + public void ConnectionIdAndConnectionTokenAreDifferentForNegotiateVersionOne() + { + using (StartVerifiableLog()) + { + var connectionManager = CreateConnectionManager(LoggerFactory); + var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default, negotiateVersion: 1); + + var transport = connection.Transport; + + Assert.NotNull(connection.ConnectionId); + Assert.NotNull(transport); + + Assert.True(connectionManager.TryGetConnection(connection.ConnectionToken, out var newConnection)); + Assert.False(connectionManager.TryGetConnection(connection.ConnectionId, out var _)); + Assert.Same(newConnection, connection); + Assert.Same(transport, newConnection.Transport); + Assert.NotEqual(connection.ConnectionId, connection.ConnectionToken); + } } diff --git a/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs b/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs index e92d3c3b42..704f0f4d27 100644 --- a/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs +++ b/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs @@ -13,12 +13,18 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests public class NegotiateProtocolTests { [Theory] - [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null)] - [InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0], null, null)] - [InlineData("{\"url\": \"http://foo.com/chat\"}", null, null, "http://foo.com/chat", null)] - [InlineData("{\"url\": \"http://foo.com/chat\", \"accessToken\": \"token\"}", null, null, "http://foo.com/chat", "token")] - [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null)] - public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports, string url, string accessToken) + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null, 0, null)] + [InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0], null, null, 0, null)] + [InlineData("{\"url\": \"http://foo.com/chat\"}", null, null, "http://foo.com/chat", null, 0, null)] + [InlineData("{\"url\": \"http://foo.com/chat\", \"accessToken\": \"token\"}", null, null, "http://foo.com/chat", "token", 0, null)] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null, 0, null)] + [InlineData("{\"negotiateVersion\":123,\"negotiateVersion\":321, \"connectionToken\":\"789\",\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null, 321, "789")] + [InlineData("{\"ignore\":123,\"negotiateVersion\":123, \"connectionToken\":\"789\",\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null, 123, "789")] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[],\"negotiateVersion\":123, \"connectionToken\":\"789\"}", "123", new string[0], null, null, 123, "789")] + [InlineData("{\"connectionId\":\"123\",\"connectionToken\":\"789\",\"availableTransports\":[]}", "123", new string[0], null, null, 0, "789")] + [InlineData("{\"connectionToken\":\"789\",\"connectionId\":\"123\",\"availableTransports\":[],\"negotiateVersion\":123}", "123", new string[0], null, null, 123, "789")] + [InlineData("{\"connectionToken\":\"789\",\"connectionId\":\"123\",\"availableTransports\":[],\"negotiateVersion\":123, \"connectionToken\":\"987\"}", "123", new string[0], null, null, 123, "987")] + public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports, string url, string accessToken, int version, string connectionToken) { var responseData = Encoding.UTF8.GetBytes(json); var response = NegotiateProtocol.ParseResponse(responseData); @@ -27,6 +33,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.Equal(availableTransports?.Length, response.AvailableTransports?.Count); Assert.Equal(url, response.Url); Assert.Equal(accessToken, response.AccessToken); + Assert.Equal(version, response.Version); + Assert.Equal(connectionToken, response.ConnectionToken); if (response.AvailableTransports != null) { @@ -44,6 +52,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests [InlineData("{\"connectionId\":\"123\",\"availableTransports\":null}", "Unexpected JSON Token Type 'Null'. Expected a JSON Array.")] [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transferFormats\":[]}]}", "Missing required property 'transport'.")] [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\"}]}", "Missing required property 'transferFormats'.")] + [InlineData("{\"connectionId\":\"123\",\"negotiateVersion\":123,\"availableTransports\":[]}", "Missing required property 'connectionToken'.")] public void ParsingNegotiateResponseMessageThrowsForInvalid(string payload, string expectedMessage) { var responseData = Encoding.UTF8.GetBytes(payload); @@ -82,7 +91,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests string json = Encoding.UTF8.GetString(writer.ToArray()); - Assert.Equal("{\"availableTransports\":[]}", json); + Assert.Equal("{\"negotiateVersion\":0,\"availableTransports\":[]}", json); } } @@ -101,7 +110,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests string json = Encoding.UTF8.GetString(writer.ToArray()); - Assert.Equal("{\"availableTransports\":[{\"transport\":null,\"transferFormats\":[]}]}", json); + Assert.Equal("{\"negotiateVersion\":0,\"availableTransports\":[{\"transport\":null,\"transferFormats\":[]}]}", json); } } }