diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs index e04e82a2b5..fa95fbc83b 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs @@ -359,7 +359,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var httpHandler = new TestHttpMessageHandler(); var connectResponseTcs = new TaskCompletionSource(); - httpHandler.OnGet("/?negotiateVersion=1&id=00000000-0000-0000-0000-000000000000", async (_, __) => + httpHandler.OnGet("/?id=00000000-0000-0000-0000-000000000000", async (_, __) => { await connectResponseTcs.Task; return ResponseUtils.CreateResponse(HttpStatusCode.Accepted); diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs index 5abbde0313..a1392a3cb5 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs @@ -200,7 +200,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.Equal("0rge0d00-0040-0030-0r00-000q00r00e00", connectionId); Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); - Assert.Equal("http://fakeuri.org/?negotiateVersion=1&id=different-id", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + Assert.Equal("http://fakeuri.org/?id=different-id", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); } [Fact] @@ -240,7 +240,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.Equal("0rge0d00-0040-0030-0r00-000q00r00e00", connectionId); Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); - Assert.Equal("http://fakeuri.org/?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + Assert.Equal("http://fakeuri.org/?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); } [Fact] @@ -298,8 +298,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); Assert.Equal("https://another.domain.url/chat/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); - Assert.Equal("https://another.domain.url/chat?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); - Assert.Equal("https://another.domain.url/chat?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); Assert.Equal(5, testHttpHandler.ReceivedRequests.Count); } @@ -402,11 +402,73 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); } + Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); + // Delete request + Assert.Equal(5, testHttpHandler.ReceivedRequests.Count); + } + + [Fact] + public async Task NegotiateThatReturnsRedirectUrlDoesNotAddAnotherNegotiateVersionQueryString() + { + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); + var negotiateCount = 0; + testHttpHandler.OnNegotiate((request, cancellationToken) => + { + negotiateCount++; + if (negotiateCount == 1) + { + return ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + url = "https://another.domain.url/chat?negotiateVersion=1" + })); + } + else + { + return ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00", + availableTransports = new object[] + { + new + { + transport = "LongPolling", + transferFormats = new[] { "Text" } + }, + } + })); + } + }); + + testHttpHandler.OnLongPoll((token) => + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + token.Register(() => tcs.TrySetResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent))); + + return tcs.Task; + }); + + testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + + using (var noErrorScope = new VerifyNoErrorsScope()) + { + await WithConnectionAsync( + CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory), + async (connection) => + { + await connection.StartAsync().OrTimeout(); + }); + } + Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); Assert.Equal("https://another.domain.url/chat/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); Assert.Equal("https://another.domain.url/chat?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); Assert.Equal("https://another.domain.url/chat?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); - // Delete request Assert.Equal(5, testHttpHandler.ReceivedRequests.Count); } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/TestHttpMessageHandler.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/TestHttpMessageHandler.cs index 36596d3236..8144d9d574 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/TestHttpMessageHandler.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/TestHttpMessageHandler.cs @@ -120,7 +120,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); testHttpMessageHandler.OnRequest((request, next, cancellationToken) => { - if (request.Method.Equals(HttpMethod.Delete) && request.RequestUri.PathAndQuery.Contains("&id=")) + if (request.Method.Equals(HttpMethod.Delete) && request.RequestUri.PathAndQuery.Contains("id=")) { deleteCts.Cancel(); return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index 2f99ac1b18..1fb9ba10aa 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -42,7 +42,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Client private readonly HttpConnectionOptions _httpConnectionOptions; private ITransport _transport; private readonly ITransportFactory _transportFactory; - private string _connectionToken; private string _connectionId; private readonly ConnectionLogScope _logScope; private readonly ILoggerFactory _loggerFactory; @@ -343,7 +342,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client } // This should only need to happen once - var connectUrl = CreateConnectUrl(uri, _connectionToken); + var connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); // We're going to search for the transfer format as a string because we don't want to parse // all the transfer formats in the negotiation response, and we want to allow transfer formats @@ -384,10 +383,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Client if (negotiationResponse == null) { negotiationResponse = await GetNegotiationResponseAsync(uri, cancellationToken); - connectUrl = CreateConnectUrl(uri, _connectionToken); + connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); } - Log.StartingTransport(_logger, transportType, connectUrl); + Log.StartingTransport(_logger, transportType, uri); await StartTransport(connectUrl, transportType, transferFormat, cancellationToken); break; } @@ -430,7 +429,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Client urlBuilder.Path += "/"; } urlBuilder.Path += "negotiate"; - var uri = Utils.AppendQueryString(urlBuilder.Uri, $"negotiateVersion={_protocolVersionNumber}"); + Uri uri; + if (urlBuilder.Query.Contains("negotiateVersion")) + { + uri = urlBuilder.Uri; + } + else + { + uri = Utils.AppendQueryString(urlBuilder.Uri, $"negotiateVersion={_protocolVersionNumber}"); + } using (var request = new HttpRequestMessage(HttpMethod.Post, uri)) { @@ -469,7 +476,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client throw new FormatException("Invalid connection id."); } - return Utils.AppendQueryString(url, $"negotiateVersion={_protocolVersionNumber}&id=" + connectionId); + return Utils.AppendQueryString(url, $"id={connectionId}"); } private async Task StartTransport(Uri connectUrl, HttpTransportType transportType, TransferFormat transferFormat, CancellationToken cancellationToken) @@ -613,14 +620,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Client // If the negotiationVersion is greater than zero then we know that the negotiation response contains a // connectionToken that will be required to conenct. Otherwise we just set the connectionId and the // connectionToken on the client to the same value. - if (negotiationResponse.Version > 0) + _connectionId = negotiationResponse.ConnectionId; + if (negotiationResponse.Version == 0) { - _connectionId = negotiationResponse.ConnectionId; - _connectionToken = negotiationResponse.ConnectionToken; - } - else - { - _connectionToken = _connectionId = negotiationResponse.ConnectionId; + negotiationResponse.ConnectionToken = _connectionId; } _logScope.ConnectionId = _connectionId; diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java index 0ad972053e..36dca4f808 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java @@ -57,7 +57,6 @@ public class HubConnection { private Map streamMap = new ConcurrentHashMap<>(); private TransportEnum transportEnum = TransportEnum.ALL; private String connectionId; - private String connectionToken; private final int negotiateVersion = 1; private final Logger logger = LoggerFactory.getLogger(HubConnection.class); @@ -262,7 +261,7 @@ public class HubConnection { HttpRequest request = new HttpRequest(); request.addHeaders(this.localHeaders); - return httpClient.post(Negotiate.resolveNegotiateUrl(url), request).map((response) -> { + return httpClient.post(Negotiate.resolveNegotiateUrl(url, this.negotiateVersion), request).map((response) -> { if (response.getStatusCode() != 200) { throw new RuntimeException(String.format("Unexpected status code returned from negotiate: %d %s.", response.getStatusCode(), response.getStatusText())); @@ -341,12 +340,11 @@ public class HubConnection { }); stopError = null; - String urlWithQS = Utils.appendQueryString(baseUrl, "negotiateVersion=" + negotiateVersion); Single negotiate = null; if (!skipNegotiate) { - negotiate = tokenCompletable.andThen(Single.defer(() -> startNegotiate(urlWithQS, 0))); + negotiate = tokenCompletable.andThen(Single.defer(() -> startNegotiate(baseUrl, 0))); } else { - negotiate = tokenCompletable.andThen(Single.defer(() -> Single.just(new NegotiateResponse(urlWithQS)))); + negotiate = tokenCompletable.andThen(Single.defer(() -> Single.just(new NegotiateResponse(baseUrl)))); } CompletableSubject start = CompletableSubject.create(); @@ -448,21 +446,21 @@ public class HubConnection { throw new RuntimeException("There were no compatible transports on the server."); } + String connectionToken = ""; if (response.getVersion() > 0) { this.connectionId = response.getConnectionId(); - this.connectionToken = response.getConnectionToken(); + connectionToken = response.getConnectionToken(); } else { - this.connectionToken = this.connectionId = response.getConnectionId(); + connectionToken = this.connectionId = response.getConnectionId(); } - String finalUrl = Utils.appendQueryString(url, "id=" + this.connectionToken); + String finalUrl = Utils.appendQueryString(url, "id=" + connectionToken); response.setFinalUrl(finalUrl); return Single.just(response); } - String redirectUrl = Utils.appendQueryString(response.getRedirectUrl(), "negotiateVersion=" + negotiateVersion); - return startNegotiate(redirectUrl, negotiateAttempts + 1); + return startNegotiate(response.getRedirectUrl(), negotiateAttempts + 1); }); } @@ -524,7 +522,6 @@ public class HubConnection { handshakeResponseSubject.onComplete(); redirectAccessTokenProvider = null; connectionId = null; - connectionToken = null; transportEnum = TransportEnum.ALL; this.localHeaders.clear(); this.streamMap.clear(); diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/Negotiate.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/Negotiate.java index d177d32fb9..73dc0ddf64 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/Negotiate.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/Negotiate.java @@ -4,7 +4,7 @@ package com.microsoft.signalr; class Negotiate { - public static String resolveNegotiateUrl(String url) { + public static String resolveNegotiateUrl(String url, int negotiateVersion) { String negotiateUrl = ""; // Check if we have a query string. If we do then we ignore it for now. @@ -15,7 +15,7 @@ class Negotiate { negotiateUrl = url; } - //Check if the url ends in a / + // Check if the url ends in a / if (negotiateUrl.charAt(negotiateUrl.length() - 1) != '/') { negotiateUrl += "/"; } @@ -27,6 +27,10 @@ class Negotiate { negotiateUrl += url.substring(queryStringIndex); } + if (!url.contains("negotiateVersion")) { + negotiateUrl = Utils.appendQueryString(negotiateUrl, "negotiateVersion=" + negotiateVersion); + } + return negotiateUrl; } } diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java index d23cdc7464..3f1d7d5b3f 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java @@ -1796,36 +1796,62 @@ class HubConnectionTest { hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", hubConnection.getConnectionId()); - assertEquals("http://example.com?negotiateVersion=1&id=connection-token-value", transport.getUrl()); + assertEquals("http://example.com?id=connection-token-value", transport.getUrl()); hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait(); assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); assertNull(hubConnection.getConnectionId()); } - @Test - public void connectionTokenIsIgnoredIfNegotiateVersionIsNotPresentInNegotiateResponse() { - TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", - (req) -> Single.just(new HttpResponse(200, "", - "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + - "\"connectionToken\":\"connection-token-value\"," + - "\"availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); + @Test + public void connectionTokenIsIgnoredIfNegotiateVersionIsNotPresentInNegotiateResponse() { + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", + (req) -> Single.just(new HttpResponse(200, "", + "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + + "\"connectionToken\":\"connection-token-value\"," + + "\"availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); - MockTransport transport = new MockTransport(true); - HubConnection hubConnection = HubConnectionBuilder - .create("http://example.com") - .withTransportImplementation(transport) - .withHttpClient(client) - .build(); + MockTransport transport = new MockTransport(true); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withTransportImplementation(transport) + .withHttpClient(client) + .build(); - assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); - assertNull(hubConnection.getConnectionId()); - hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); - assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); - assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", hubConnection.getConnectionId()); - assertEquals("http://example.com?negotiateVersion=1&id=bVOiRPG8-6YiJ6d7ZcTOVQ", transport.getUrl()); - hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait(); - assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); - assertNull(hubConnection.getConnectionId()); + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + assertNull(hubConnection.getConnectionId()); + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", hubConnection.getConnectionId()); + assertEquals("http://example.com?id=bVOiRPG8-6YiJ6d7ZcTOVQ", transport.getUrl()); + hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + assertNull(hubConnection.getConnectionId()); + } + + @Test + public void negotiateVersionIsNotAddedIfAlreadyPresent() { + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=42", + (req) -> Single.just(new HttpResponse(200, "", + "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + + "\"connectionToken\":\"connection-token-value\"," + + "\"availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); + + MockTransport transport = new MockTransport(true); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com?negotiateVersion=42") + .withTransportImplementation(transport) + .withHttpClient(client) + .build(); + + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + assertNull(hubConnection.getConnectionId()); + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", hubConnection.getConnectionId()); + assertEquals("http://example.com?negotiateVersion=42&id=bVOiRPG8-6YiJ6d7ZcTOVQ", transport.getUrl()); + hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + assertNull(hubConnection.getConnectionId()); } @Test @@ -2115,7 +2141,7 @@ class HubConnectionTest { hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); - assertEquals("http://testexample.com/?negotiateVersion=1&id=connection-token-value", transport.getUrl()); + assertEquals("http://testexample.com/?id=connection-token-value", transport.getUrl()); hubConnection.stop(); assertEquals("Bearer newToken", token.get()); } diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/ResolveNegotiateUrlTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/ResolveNegotiateUrlTest.java index eacee30346..1c2a213f6e 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/ResolveNegotiateUrlTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/ResolveNegotiateUrlTest.java @@ -14,17 +14,18 @@ import org.junit.jupiter.params.provider.MethodSource; class ResolveNegotiateUrlTest { private static Stream protocols() { return Stream.of( - Arguments.of("http://example.com/hub/", "http://example.com/hub/negotiate"), - Arguments.of("http://example.com/hub", "http://example.com/hub/negotiate"), - Arguments.of("http://example.com/endpoint?q=my/Data", "http://example.com/endpoint/negotiate?q=my/Data"), - Arguments.of("http://example.com/endpoint/?q=my/Data", "http://example.com/endpoint/negotiate?q=my/Data"), - Arguments.of("http://example.com/endpoint/path/more?q=my/Data", "http://example.com/endpoint/path/more/negotiate?q=my/Data")); + Arguments.of("http://example.com/hub/", 0, "http://example.com/hub/negotiate?negotiateVersion=0"), + Arguments.of("http://example.com/hub", 1, "http://example.com/hub/negotiate?negotiateVersion=1"), + Arguments.of("http://example.com/endpoint?q=my/Data", 0, "http://example.com/endpoint/negotiate?q=my/Data&negotiateVersion=0"), + Arguments.of("http://example.com/endpoint/?q=my/Data", 1, "http://example.com/endpoint/negotiate?q=my/Data&negotiateVersion=1"), + Arguments.of("http://example.com/endpoint/path/more?q=my/Data", 0, "http://example.com/endpoint/path/more/negotiate?q=my/Data&negotiateVersion=0"), + Arguments.of("http://example.com/hub/?negotiateVersion=2", 0, "http://example.com/hub/negotiate?negotiateVersion=2")); } @ParameterizedTest @MethodSource("protocols") - public void checkNegotiateUrl(String url, String resolvedUrl) { - String urlResult = Negotiate.resolveNegotiateUrl(url); + public void checkNegotiateUrl(String url, int negotiateVersion, String resolvedUrl) { + String urlResult = Negotiate.resolveNegotiateUrl(url, negotiateVersion); assertEquals(resolvedUrl, urlResult); } } \ No newline at end of file diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index 7bd4acc682..6ed9932839 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -64,7 +64,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal var connectionToken = GetConnectionToken(context); if (connectionToken != null) { - _manager.TryGetConnection(GetConnectionToken(context), out connectionContext); + _manager.TryGetConnection(connectionToken, out connectionContext); } var logScope = new ConnectionLogScope(connectionContext?.ConnectionId); @@ -298,6 +298,21 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal error = $"The client requested an invalid protocol version '{queryStringVersionValue}'"; Log.InvalidNegotiateProtocolVersion(_logger, queryStringVersionValue); } + else if (clientProtocolVersion < options.MinimumProtocolVersion) + { + error = $"The client requested version '{clientProtocolVersion}', but the server does not support this version."; + Log.NegotiateProtocolVersionMismatch(_logger, clientProtocolVersion); + } + else if (clientProtocolVersion > _protocolVersion) + { + clientProtocolVersion = _protocolVersion; + } + } + else if (options.MinimumProtocolVersion > 0) + { + // NegotiateVersion wasn't parsed meaning the client requests version 0. + error = $"The client requested version '0', but the server does not support this version."; + Log.NegotiateProtocolVersionMismatch(_logger, 0); } // Establish the connection @@ -343,32 +358,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal return; } - if (clientProtocolVersion > 0) - { - if (clientProtocolVersion < options.MinimumProtocolVersion) - { - response.Error = $"The client requested version '{clientProtocolVersion}', but the server does not support this version."; - Log.NegotiateProtocolVersionMismatch(_logger, clientProtocolVersion); - NegotiateProtocol.WriteResponse(response, writer); - return; - } - else if (clientProtocolVersion > _protocolVersion) - { - response.Version = _protocolVersion; - } - else - { - response.Version = clientProtocolVersion; - } - } - else if (options.MinimumProtocolVersion > 0) - { - // NegotiateVersion wasn't parsed meaning the client requests version 0. - response.Error = $"The client requested version '0', but the server does not support this version."; - NegotiateProtocol.WriteResponse(response, writer); - return; - } - + response.Version = clientProtocolVersion; response.ConnectionId = connectionId; response.ConnectionToken = connectionToken; response.AvailableTransports = new List();