diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubProtocolVersionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubProtocolVersionTests.cs new file mode 100644 index 0000000000..5e6b7a8295 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubProtocolVersionTests.cs @@ -0,0 +1,230 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Connections; +using Microsoft.AspNetCore.Http.Connections.Client; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.SignalR.Tests; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; +using Microsoft.Extensions.Options; +using Newtonsoft.Json.Linq; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests +{ + // Disable running server tests in parallel so server logs can accurately be captured per test + [CollectionDefinition(Name, DisableParallelization = true)] + public class HubProtocolVersionTestsCollection : ICollectionFixture> + { + public const string Name = nameof(HubProtocolVersionTestsCollection); + } + + [Collection(HubProtocolVersionTestsCollection.Name)] + public class HubProtocolVersionTests : VerifiableServerLoggedTest + { + public HubProtocolVersionTests(ServerFixture serverFixture, ITestOutputHelper output) : base(serverFixture, output) + { + } + + [Theory] + [MemberData(nameof(TransportTypes))] + public async Task ClientUsingOldCallWithOriginalProtocol(HttpTransportType transportType) + { + using (StartVerifiableLog(out var loggerFactory, $"{nameof(ClientUsingOldCallWithOriginalProtocol)}_{transportType}")) + { + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(loggerFactory) + .WithUrl(ServerFixture.Url + "/version", transportType); + + var connection = connectionBuilder.Build(); + + try + { + await connection.StartAsync().OrTimeout(); + + var result = await connection.InvokeAsync(nameof(VersionHub.Echo), "Hello World!").OrTimeout(); + + Assert.Equal("Hello World!", result); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + + [Theory] + [MemberData(nameof(TransportTypes))] + public async Task ClientUsingOldCallWithNewProtocol(HttpTransportType transportType) + { + using (StartVerifiableLog(out var loggerFactory, $"{nameof(ClientUsingOldCallWithNewProtocol)}_{transportType}")) + { + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(loggerFactory) + .WithUrl(ServerFixture.Url + "/version", transportType); + connectionBuilder.Services.AddSingleton(new VersionedJsonHubProtocol(1000)); + + var connection = connectionBuilder.Build(); + + try + { + await connection.StartAsync().OrTimeout(); + + var result = await connection.InvokeAsync(nameof(VersionHub.Echo), "Hello World!").OrTimeout(); + + Assert.Equal("Hello World!", result); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + + [Theory] + [MemberData(nameof(TransportTypes))] + public async Task ClientUsingNewCallWithNewProtocol(HttpTransportType transportType) + { + using (StartVerifiableLog(out var loggerFactory, $"{nameof(ClientUsingNewCallWithNewProtocol)}_{transportType}")) + { + var httpConnectionFactory = new HttpConnectionFactory(Options.Create(new HttpConnectionOptions + { + Url = new Uri(ServerFixture.Url + "/version"), + Transports = transportType + }), loggerFactory); + var tcs = new TaskCompletionSource(); + + var proxyConnectionFactory = new ProxyConnectionFactory(httpConnectionFactory); + + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(loggerFactory); + connectionBuilder.Services.AddSingleton(new VersionedJsonHubProtocol(1000)); + connectionBuilder.Services.AddSingleton(proxyConnectionFactory); + + var connection = connectionBuilder.Build(); + connection.On("NewProtocolMethodClient", () => + { + tcs.SetResult(null); + }); + + try + { + await connection.StartAsync().OrTimeout(); + + // Task should already have been awaited in StartAsync + var connectionContext = await proxyConnectionFactory.ConnectTask.OrTimeout(); + + // Simulate a new call from the client + var messageToken = new JObject + { + ["type"] = int.MaxValue + }; + + connectionContext.Transport.Output.Write(Encoding.UTF8.GetBytes(messageToken.ToString())); + connectionContext.Transport.Output.Write(new[] { (byte)0x1e }); + await connectionContext.Transport.Output.FlushAsync().OrTimeout(); + + await tcs.Task.OrTimeout(); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + + [Theory] + [MemberData(nameof(TransportTypes))] + public async Task ClientWithUnsupportedProtocolVersionDoesNotConnect(HttpTransportType transportType) + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName; + } + + using (StartVerifiableLog(out var loggerFactory, LogLevel.Trace, $"{nameof(ClientWithUnsupportedProtocolVersionDoesNotConnect)}_{transportType}", expectedErrorsFilter: ExpectedErrors)) + { + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(loggerFactory) + .WithUrl(ServerFixture.Url + "/version", transportType); + connectionBuilder.Services.AddSingleton(new VersionedJsonHubProtocol(int.MaxValue)); + + var connection = connectionBuilder.Build(); + + try + { + await ExceptionAssert.ThrowsAsync( + () => connection.StartAsync(), + "Unable to complete handshake with the server due to an error: The server does not support version 2147483647 of the 'json' protocol.").OrTimeout(); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + + private class ProxyConnectionFactory : IConnectionFactory + { + private readonly IConnectionFactory _innerFactory; + public Task ConnectTask { get; private set; } + + public ProxyConnectionFactory(IConnectionFactory innerFactory) + { + _innerFactory = innerFactory; + } + + public Task ConnectAsync(TransferFormat transferFormat, CancellationToken cancellationToken = default) + { + ConnectTask = _innerFactory.ConnectAsync(transferFormat, cancellationToken); + return ConnectTask; + } + + public Task DisposeAsync(ConnectionContext connection) + { + return _innerFactory.DisposeAsync(connection); + } + } + + public static IEnumerable TransportTypes() + { + if (TestHelpers.IsWebSocketsSupported()) + { + yield return new object[] { HttpTransportType.WebSockets }; + } + yield return new object[] { HttpTransportType.ServerSentEvents }; + yield return new object[] { HttpTransportType.LongPolling }; + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs index 3e4fb69a3c..e79babd306 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs @@ -178,6 +178,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests Task NoClientHandler(); } + public class VersionHub : Hub + { + public string Echo(string message) => message; + + public Task NewProtocolMethodServer() + { + return Clients.Caller.SendAsync("NewProtocolMethodClient"); + } + } + [Authorize(JwtBearerDefaults.AuthenticationScheme)] public class HubWithAuthorization : Hub { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/VersionJsonHubProtocol.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/VersionJsonHubProtocol.cs new file mode 100644 index 0000000000..86766b088b --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/VersionJsonHubProtocol.cs @@ -0,0 +1,86 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.SignalR.Protocol; +using Newtonsoft.Json.Linq; + +namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests +{ + public class VersionedJsonHubProtocol : IHubProtocol + { + private readonly int _version; + private readonly JsonHubProtocol _innerProtocol; + + public VersionedJsonHubProtocol(int version) + { + _version = version; + _innerProtocol = new JsonHubProtocol(); + } + + public string Name => _innerProtocol.Name; + public int Version => _version; + public TransferFormat TransferFormat => _innerProtocol.TransferFormat; + + public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, out HubMessage message) + { + var inputCopy = input; + if (!TryParseMessage(ref input, out var payload)) + { + message = null; + return false; + } + + // Handle "new" call + var json = Encoding.UTF8.GetString(payload.ToArray()); + var o = JObject.Parse(json); + if ((int)o["type"] == int.MaxValue) + { + message = new InvocationMessage("NewProtocolMethodServer", Array.Empty()); + return true; + } + + // Handle "old" calls + var result = _innerProtocol.TryParseMessage(ref inputCopy, binder, out message); + input = inputCopy; + return result; + } + + public static bool TryParseMessage(ref ReadOnlySequence buffer, out ReadOnlySequence payload) + { + var position = buffer.PositionOf((byte)0x1e); + if (position == null) + { + payload = default; + return false; + } + + payload = buffer.Slice(0, position.Value); + + // Skip record separator + buffer = buffer.Slice(buffer.GetPosition(1, position.Value)); + + return true; + } + + public void WriteMessage(HubMessage message, IBufferWriter output) + { + _innerProtocol.WriteMessage(message, output); + } + + public ReadOnlyMemory GetMessageBytes(HubMessage message) + { + return _innerProtocol.GetMessageBytes(message); + } + + public bool IsVersionSupported(int version) + { + // Support older clients + return version <= _version; + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/VersionStartup.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/VersionStartup.cs new file mode 100644 index 0000000000..18972ebb87 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/VersionStartup.cs @@ -0,0 +1,44 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IdentityModel.Tokens.Jwt; +using System.Security.Claims; +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Connections; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.IdentityModel.Tokens; +using Newtonsoft.Json; + +namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests +{ + public class VersionStartup + { + public void ConfigureServices(IServiceCollection services) + { + services.AddSignalR(options => + { + options.EnableDetailedErrors = true; + }); + + services.RemoveAll(); + services.TryAddEnumerable(ServiceDescriptor.Singleton(new VersionedJsonHubProtocol(1000))); + + services.AddAuthentication(); + } + + public void Configure(IApplicationBuilder app) + { + app.UseAuthentication(); + + app.UseSignalR(routes => + { + routes.MapHub("/version"); + }); + } + } +}