From ea8c2b92f4b273bcad1f9246ca87cd4209f10657 Mon Sep 17 00:00:00 2001 From: Mikael Mengistu Date: Tue, 3 Sep 2019 19:53:32 -0700 Subject: [PATCH] SignalR Negotiate protocol versioning (#13389) --- .../FunctionalTests/HubConnectionTests.cs | 65 +++++++++++++++++++ .../Client/test/FunctionalTests/Startup.cs | 10 +++ ...HttpConnectionTests.ConnectionLifecycle.cs | 2 +- .../HttpConnectionTests.Negotiate.cs | 65 +++++++++++++++---- .../test/UnitTests/TestHttpMessageHandler.cs | 5 +- .../src/HttpConnection.cs | 6 +- ...Core.Http.Connections.Common.netcoreapp.cs | 1 + ....Http.Connections.Common.netstandard2.0.cs | 1 + .../src/NegotiateProtocol.cs | 21 ++++++ .../src/NegotiationResponse.cs | 1 + ....AspNetCore.Http.Connections.netcoreapp.cs | 1 + .../src/HttpConnectionDispatcherOptions.cs | 6 ++ .../Internal/HttpConnectionDispatcher.Log.cs | 16 +++++ .../src/Internal/HttpConnectionDispatcher.cs | 42 +++++++++++- .../test/HttpConnectionDispatcherTests.cs | 58 +++++++++++++++++ .../test/NegotiateProtocolTests.cs | 22 ++++--- 16 files changed, 295 insertions(+), 27 deletions(-) diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index 14e9d78e5b..6dbd4e032e 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -114,6 +114,71 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Fact] + public async Task ServerRejectsClientWithOldProtocol() + { + bool ExpectedError(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HttpConnection).FullName && + writeContext.EventId.Name == "ErrorWithNegotiation"; + } + + var protocol = HubProtocols["json"]; + using (StartServer(out var server, ExpectedError)) + { + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + "/negotiateProtocolVersion12", HttpTransportType.LongPolling); + connectionBuilder.Services.AddSingleton(protocol); + + var connection = connectionBuilder.Build(); + + try + { + var ex = await Assert.ThrowsAnyAsync(() => connection.StartAsync()).OrTimeout(); + Assert.Equal("The client requested version '1', but the server does not support this version.", ex.Message); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + + [Fact] + public async Task ClientCanConnectToServerWithLowerMinimumProtocol() + { + var protocol = HubProtocols["json"]; + using (StartServer(out var server)) + { + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + "/negotiateProtocolVersionNegative", HttpTransportType.LongPolling); + connectionBuilder.Services.AddSingleton(protocol); + + var connection = connectionBuilder.Build(); + + try + { + await connection.StartAsync().OrTimeout(); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] public async Task CanSendAndReceiveMessage(string protocolName, HttpTransportType transportType, string path) diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/Startup.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/Startup.cs index 1d7dbd6718..4cbc35c510 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/Startup.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/Startup.cs @@ -69,6 +69,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests endpoints.MapHub("/default-nowebsockets", options => options.Transports = HttpTransportType.LongPolling | HttpTransportType.ServerSentEvents); + endpoints.MapHub("/negotiateProtocolVersion12", options => + { + options.MinimumProtocolVersion = 12; + }); + + endpoints.MapHub("/negotiateProtocolVersionNegative", options => + { + options.MinimumProtocolVersion = -1; + }); + endpoints.MapGet("/generateJwtToken", context => { return context.Response.WriteAsync(GenerateJwtToken()); diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs index fa95fbc83b..e04e82a2b5 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs @@ -359,7 +359,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var httpHandler = new TestHttpMessageHandler(); var connectResponseTcs = new TaskCompletionSource(); - httpHandler.OnGet("/?id=00000000-0000-0000-0000-000000000000", async (_, __) => + httpHandler.OnGet("/?negotiateVersion=1&id=00000000-0000-0000-0000-000000000000", async (_, __) => { await connectResponseTcs.Task; return ResponseUtils.CreateResponse(HttpStatusCode.Accepted); diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs index 348e33cebf..47a16a48d9 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs @@ -50,12 +50,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } [Theory] - [InlineData("http://fakeuri.org/", "http://fakeuri.org/negotiate")] - [InlineData("http://fakeuri.org/?q=1/0", "http://fakeuri.org/negotiate?q=1/0")] - [InlineData("http://fakeuri.org?q=1/0", "http://fakeuri.org/negotiate?q=1/0")] - [InlineData("http://fakeuri.org/endpoint", "http://fakeuri.org/endpoint/negotiate")] - [InlineData("http://fakeuri.org/endpoint/", "http://fakeuri.org/endpoint/negotiate")] - [InlineData("http://fakeuri.org/endpoint?q=1/0", "http://fakeuri.org/endpoint/negotiate?q=1/0")] + [InlineData("http://fakeuri.org/", "http://fakeuri.org/negotiate?negotiateVersion=1")] + [InlineData("http://fakeuri.org/?q=1/0", "http://fakeuri.org/negotiate?q=1/0&negotiateVersion=1")] + [InlineData("http://fakeuri.org?q=1/0", "http://fakeuri.org/negotiate?q=1/0&negotiateVersion=1")] + [InlineData("http://fakeuri.org/endpoint", "http://fakeuri.org/endpoint/negotiate?negotiateVersion=1")] + [InlineData("http://fakeuri.org/endpoint/", "http://fakeuri.org/endpoint/negotiate?negotiateVersion=1")] + [InlineData("http://fakeuri.org/endpoint?q=1/0", "http://fakeuri.org/endpoint/negotiate?q=1/0&negotiateVersion=1")] public async Task CorrectlyHandlesQueryStringWhenAppendingNegotiateToUrl(string requestedUrl, string expectedNegotiate) { var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); @@ -119,6 +119,43 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.Equal("0rge0d00-0040-0030-0r00-000q00r00e00", connectionId); } + [Fact] + public async Task NegotiateCanHaveNewFields() + { + string connectionId = null; + + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); + testHttpHandler.OnNegotiate((request, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00", + availableTransports = new object[] + { + new + { + transport = "LongPolling", + transferFormats = new[] { "Text" } + }, + }, + newField = "ignore this", + }))); + testHttpHandler.OnLongPoll(cancellationToken => ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + + using (var noErrorScope = new VerifyNoErrorsScope()) + { + await WithConnectionAsync( + CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory), + async (connection) => + { + await connection.StartAsync().OrTimeout(); + connectionId = connection.ConnectionId; + }); + } + + Assert.Equal("0rge0d00-0040-0030-0r00-000q00r00e00", connectionId); + } + [Fact] public async Task NegotiateThatReturnsUrlGetFollowed() { @@ -172,10 +209,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); } - Assert.Equal("http://fakeuri.org/negotiate", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); - Assert.Equal("https://another.domain.url/chat/negotiate", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); - Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); - Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); + Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); Assert.Equal(5, testHttpHandler.ReceivedRequests.Count); } @@ -278,10 +315,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); } - Assert.Equal("http://fakeuri.org/negotiate", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); - Assert.Equal("https://another.domain.url/chat/negotiate", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); - Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); - Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); + Assert.Equal("http://fakeuri.org/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat/negotiate?negotiateVersion=1", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); // Delete request Assert.Equal(5, testHttpHandler.ReceivedRequests.Count); } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/TestHttpMessageHandler.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/TestHttpMessageHandler.cs index 06d05da7f5..36596d3236 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/TestHttpMessageHandler.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/TestHttpMessageHandler.cs @@ -1,3 +1,6 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + using System; using System.Collections.Generic; using System.Net; @@ -117,7 +120,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); testHttpMessageHandler.OnRequest((request, next, cancellationToken) => { - if (request.Method.Equals(HttpMethod.Delete) && request.RequestUri.PathAndQuery.StartsWith("/?id=")) + if (request.Method.Equals(HttpMethod.Delete) && request.RequestUri.PathAndQuery.Contains("&id=")) { deleteCts.Cancel(); return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index 852681d963..bcd2e2dcc1 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -26,6 +26,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client // Not configurable on purpose, high enough that if we reach here, it's likely // a buggy server private static readonly int _maxRedirects = 100; + private static readonly int _protocolVersionNumber = 1; private static readonly Task _noAccessToken = Task.FromResult(null); private static readonly TimeSpan HttpClientTimeout = TimeSpan.FromSeconds(120); @@ -428,8 +429,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Client urlBuilder.Path += "/"; } urlBuilder.Path += "negotiate"; + var uri = Utils.AppendQueryString(urlBuilder.Uri, $"negotiateVersion={_protocolVersionNumber}"); - using (var request = new HttpRequestMessage(HttpMethod.Post, urlBuilder.Uri)) + using (var request = new HttpRequestMessage(HttpMethod.Post, uri)) { // Corefx changed the default version and High Sierra curlhandler tries to upgrade request request.Version = new Version(1, 1); @@ -466,7 +468,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client throw new FormatException("Invalid connection id."); } - return Utils.AppendQueryString(url, "id=" + connectionId); + return Utils.AppendQueryString(url, $"negotiateVersion={_protocolVersionNumber}&id=" + connectionId); } private async Task StartTransport(Uri connectUrl, HttpTransportType transportType, TransferFormat transferFormat, CancellationToken cancellationToken) diff --git a/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netcoreapp.cs b/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netcoreapp.cs index 27d206da33..fb58827e4b 100644 --- a/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netcoreapp.cs +++ b/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netcoreapp.cs @@ -36,5 +36,6 @@ namespace Microsoft.AspNetCore.Http.Connections public string ConnectionId { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string Error { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string Url { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public int Version { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } } } diff --git a/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netstandard2.0.cs b/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netstandard2.0.cs index 27d206da33..fb58827e4b 100644 --- a/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netstandard2.0.cs +++ b/src/SignalR/common/Http.Connections.Common/ref/Microsoft.AspNetCore.Http.Connections.Common.netstandard2.0.cs @@ -36,5 +36,6 @@ namespace Microsoft.AspNetCore.Http.Connections public string ConnectionId { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string Error { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public string Url { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public int Version { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } } } diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs index a23a9d6c0b..9bb79fd055 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs @@ -27,6 +27,8 @@ namespace Microsoft.AspNetCore.Http.Connections private static JsonEncodedText TransferFormatsPropertyNameBytes = JsonEncodedText.Encode(TransferFormatsPropertyName); private const string ErrorPropertyName = "error"; private static JsonEncodedText ErrorPropertyNameBytes = JsonEncodedText.Encode(ErrorPropertyName); + private const string NegotiateVersionPropertyName = "negotiateVersion"; + private static JsonEncodedText NegotiateVersionPropertyNameBytes = JsonEncodedText.Encode(NegotiateVersionPropertyName); // Use C#7.3's ReadOnlySpan optimization for static data https://vcsjones.com/2019/02/01/csharp-readonly-span-bytes-static/ // Used to detect ASP.NET SignalR Server connection attempt @@ -41,6 +43,19 @@ namespace Microsoft.AspNetCore.Http.Connections var writer = reusableWriter.GetJsonWriter(); writer.WriteStartObject(); + // If we already have an error its due to a protocol version incompatibility. + // We can just write the error and complete the JSON object and return. + if (!string.IsNullOrEmpty(response.Error)) + { + writer.WriteString(ErrorPropertyNameBytes, response.Error); + writer.WriteEndObject(); + writer.Flush(); + Debug.Assert(writer.CurrentDepth == 0); + return; + } + + writer.WriteNumber(NegotiateVersionPropertyNameBytes, response.Version); + if (!string.IsNullOrEmpty(response.Url)) { writer.WriteString(UrlPropertyNameBytes, response.Url); @@ -116,6 +131,7 @@ namespace Microsoft.AspNetCore.Http.Connections string accessToken = null; List availableTransports = null; string error = null; + int version = 0; var completed = false; while (!completed && reader.CheckRead()) @@ -135,6 +151,10 @@ namespace Microsoft.AspNetCore.Http.Connections { connectionId = reader.ReadAsString(ConnectionIdPropertyName); } + else if (reader.ValueTextEquals(NegotiateVersionPropertyNameBytes.EncodedUtf8Bytes)) + { + version = reader.ReadAsInt32(NegotiateVersionPropertyName).GetValueOrDefault(); + } else if (reader.ValueTextEquals(AvailableTransportsPropertyNameBytes.EncodedUtf8Bytes)) { reader.CheckRead(); @@ -195,6 +215,7 @@ namespace Microsoft.AspNetCore.Http.Connections AccessToken = accessToken, AvailableTransports = availableTransports, Error = error, + Version = version }; } catch (Exception ex) diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs index a01d2e637c..cd21b6cb26 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs @@ -10,6 +10,7 @@ namespace Microsoft.AspNetCore.Http.Connections public string Url { get; set; } public string AccessToken { get; set; } public string ConnectionId { get; set; } + public int Version { get; set; } public IList AvailableTransports { get; set; } public string Error { get; set; } } diff --git a/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp.cs b/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp.cs index 7810a4985d..5ee369727c 100644 --- a/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp.cs +++ b/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp.cs @@ -53,6 +53,7 @@ namespace Microsoft.AspNetCore.Http.Connections public long ApplicationMaxBufferSize { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.Collections.Generic.IList AuthorizationData { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } public Microsoft.AspNetCore.Http.Connections.LongPollingOptions LongPolling { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } + public int MinimumProtocolVersion { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public long TransportMaxBufferSize { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public Microsoft.AspNetCore.Http.Connections.HttpTransportType Transports { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public Microsoft.AspNetCore.Http.Connections.WebSocketOptions WebSockets { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } diff --git a/src/SignalR/common/Http.Connections/src/HttpConnectionDispatcherOptions.cs b/src/SignalR/common/Http.Connections/src/HttpConnectionDispatcherOptions.cs index eff4ae76e4..e1f97d7183 100644 --- a/src/SignalR/common/Http.Connections/src/HttpConnectionDispatcherOptions.cs +++ b/src/SignalR/common/Http.Connections/src/HttpConnectionDispatcherOptions.cs @@ -57,5 +57,11 @@ namespace Microsoft.AspNetCore.Http.Connections /// Gets or sets the maximum buffer size of the application writer. /// public long ApplicationMaxBufferSize { get; set; } + + /// + /// Gets or sets the minimum protocol verison supported by the server. + /// The default value is 0, the lowest possible protocol version. + /// + public int MinimumProtocolVersion { get; set; } = 0; } } diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.Log.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.Log.cs index af91f08af2..80f3d32800 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.Log.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.Log.cs @@ -52,6 +52,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private static readonly Action _failedToReadHttpRequestBody = LoggerMessage.Define(LogLevel.Debug, new EventId(14, "FailedToReadHttpRequestBody"), "Connection {TransportConnectionId} failed to read the HTTP request body."); + private static readonly Action _negotiateProtocolVersionMismatch = + LoggerMessage.Define(LogLevel.Debug, new EventId(15, "NegotiateProtocolVersionMismatch"), "The client requested version '{clientProtocolVersion}', but the server does not support this version."); + + private static readonly Action _invalidNegotiateProtocolVersion = + LoggerMessage.Define(LogLevel.Debug, new EventId(16, "InvalidNegotiateProtocolVersion"), "The client requested an invalid protocol version '{queryStringVersionValue}'"); + public static void ConnectionDisposed(ILogger logger, string connectionId) { _connectionDisposed(logger, connectionId, null); @@ -121,6 +127,16 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal { _failedToReadHttpRequestBody(logger, connectionId, ex); } + + public static void NegotiateProtocolVersionMismatch(ILogger logger, int clientProtocolVersion) + { + _negotiateProtocolVersionMismatch(logger, clientProtocolVersion, null); + } + + public static void InvalidNegotiateProtocolVersion(ILogger logger, string requestedProtocolVersion) + { + _invalidNegotiateProtocolVersion(logger, requestedProtocolVersion, null); + } } } } diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index bf82562a7b..983b1270f6 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -45,6 +45,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private readonly HttpConnectionManager _manager; private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; + private static readonly int _protocolVersion = 1; public HttpConnectionDispatcher(HttpConnectionManager manager, ILoggerFactory loggerFactory) { @@ -306,9 +307,48 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal } } - private static void WriteNegotiatePayload(IBufferWriter writer, string connectionId, HttpContext context, HttpConnectionDispatcherOptions options) + private void WriteNegotiatePayload(IBufferWriter writer, string connectionId, HttpContext context, HttpConnectionDispatcherOptions options) { var response = new NegotiationResponse(); + + if (context.Request.Query.TryGetValue("NegotiateVersion", out var queryStringVersion)) + { + // Set the negotiate response to the protocol we use. + var queryStringVersionValue = queryStringVersion.ToString(); + if (int.TryParse(queryStringVersionValue, out var clientProtocolVersion)) + { + if (clientProtocolVersion < options.MinimumProtocolVersion) + { + response.Error = $"The client requested version '{clientProtocolVersion}', but the server does not support this version."; + Log.NegotiateProtocolVersionMismatch(_logger, clientProtocolVersion); + NegotiateProtocol.WriteResponse(response, writer); + return; + } + else if (clientProtocolVersion > _protocolVersion) + { + response.Version = _protocolVersion; + } + else + { + response.Version = clientProtocolVersion; + } + } + else + { + response.Error = $"The client requested an invalid protocol version '{queryStringVersionValue}'"; + Log.InvalidNegotiateProtocolVersion(_logger, queryStringVersionValue); + NegotiateProtocol.WriteResponse(response, writer); + return; + } + } + else if (options.MinimumProtocolVersion > 0) + { + // NegotiateVersion wasn't parsed meaning the client requests version 0. + response.Error = $"The client requested version '0', but the server does not support this version."; + NegotiateProtocol.WriteResponse(response, writer); + return; + } + response.ConnectionId = connectionId; response.AvailableTransports = new List(); diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index 7b164b8929..c01be04894 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -95,6 +95,64 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } + [Fact] + public async Task InvalidNegotiateProtocolVersionThrows() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddOptions(); + var ms = new MemoryStream(); + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + context.Response.Body = ms; + context.Request.QueryString = new QueryString("?negotiateVersion=Invalid"); + + var options = new HttpConnectionDispatcherOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4 }; + await dispatcher.ExecuteNegotiateAsync(context, options); + var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); + + var error = negotiateResponse.Value("error"); + Assert.Equal("The client requested an invalid protocol version 'Invalid'", error); + + var connectionId = negotiateResponse.Value("connectionId"); + Assert.Null(connectionId); + } + } + + [Fact] + public async Task NoNegotiateVersionInQueryStringThrowsWhenMinProtocolVersionIsSet() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddOptions(); + var ms = new MemoryStream(); + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + context.Response.Body = ms; + context.Request.QueryString = new QueryString(""); + + var options = new HttpConnectionDispatcherOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4, MinimumProtocolVersion = 1 }; + await dispatcher.ExecuteNegotiateAsync(context, options); + var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); + + var error = negotiateResponse.Value("error"); + Assert.Equal("The client requested version '0', but the server does not support this version.", error); + + var connectionId = negotiateResponse.Value("connectionId"); + Assert.Null(connectionId); + } + } + [Theory] [InlineData(HttpTransportType.LongPolling)] [InlineData(HttpTransportType.ServerSentEvents)] diff --git a/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs b/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs index e92d3c3b42..c7d274d0b3 100644 --- a/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs +++ b/src/SignalR/common/Http.Connections/test/NegotiateProtocolTests.cs @@ -13,12 +13,17 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests public class NegotiateProtocolTests { [Theory] - [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null)] - [InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0], null, null)] - [InlineData("{\"url\": \"http://foo.com/chat\"}", null, null, "http://foo.com/chat", null)] - [InlineData("{\"url\": \"http://foo.com/chat\", \"accessToken\": \"token\"}", null, null, "http://foo.com/chat", "token")] - [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null)] - public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports, string url, string accessToken) + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null, 0)] + [InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0], null, null, 0)] + [InlineData("{\"url\": \"http://foo.com/chat\"}", null, null, "http://foo.com/chat", null, 0)] + [InlineData("{\"url\": \"http://foo.com/chat\", \"accessToken\": \"token\"}", null, null, "http://foo.com/chat", "token", 0)] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null, 0)] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"\\u0074ransport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null, 0)] + [InlineData("{\"negotiateVersion\":123,\"connectionId\":\"123\",\"availableTransports\":[{\"\\u0074ransport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null, 123)] + [InlineData("{\"negotiateVersion\":123,\"negotiateVersion\":321,\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null, 321)] + [InlineData("{\"ignore\":123,\"negotiateVersion\":123,\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null, 123)] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[],\"negotiateVersion\":123}", "123", new string[0], null, null, 123)] + public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports, string url, string accessToken, int version) { var responseData = Encoding.UTF8.GetBytes(json); var response = NegotiateProtocol.ParseResponse(responseData); @@ -27,6 +32,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.Equal(availableTransports?.Length, response.AvailableTransports?.Count); Assert.Equal(url, response.Url); Assert.Equal(accessToken, response.AccessToken); + Assert.Equal(version, response.Version); if (response.AvailableTransports != null) { @@ -82,7 +88,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests string json = Encoding.UTF8.GetString(writer.ToArray()); - Assert.Equal("{\"availableTransports\":[]}", json); + Assert.Equal("{\"negotiateVersion\":0,\"availableTransports\":[]}", json); } } @@ -101,7 +107,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests string json = Encoding.UTF8.GetString(writer.ToArray()); - Assert.Equal("{\"availableTransports\":[{\"transport\":null,\"transferFormats\":[]}]}", json); + Assert.Equal("{\"negotiateVersion\":0,\"availableTransports\":[{\"transport\":null,\"transferFormats\":[]}]}", json); } } }