diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 86fe4b32ac..b64804a9d6 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -261,7 +261,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++) { if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer && - hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType() == descriptor.OriginalParameterTypes[parameterPointer]) + descriptor.OriginalParameterTypes[parameterPointer].IsAssignableFrom(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType())) { // The types match so it isn't a synthetic argument, just copy it into the arguments array arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer]; diff --git a/src/SignalR/server/SignalR/test/Internal/DefaultHubDispatcherTests.cs b/src/SignalR/server/SignalR/test/Internal/DefaultHubDispatcherTests.cs new file mode 100644 index 0000000000..f28ec35cb1 --- /dev/null +++ b/src/SignalR/server/SignalR/test/Internal/DefaultHubDispatcherTests.cs @@ -0,0 +1,143 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Tests.Internal +{ + public class DefaultHubDispatcherTests + { + private class MockHubConnectionContext : HubConnectionContext + { + public TaskCompletionSource ReceivedCompleted = new TaskCompletionSource(); + public List Values = new List(); + + public MockHubConnectionContext(ConnectionContext connectionContext, HubConnectionContextOptions contextOptions, ILoggerFactory loggerFactory) + : base(connectionContext, contextOptions, loggerFactory) { } + + public override ValueTask WriteAsync(HubMessage message, CancellationToken cancellationToken) + { + if (message is StreamItemMessage streamItemMessage) + Values.Add((TValue)streamItemMessage.Item); + + else if (message is CompletionMessage completionMessage) + { + ReceivedCompleted.TrySetResult(null); + + if (!string.IsNullOrEmpty(completionMessage.Error)) + { + throw new Exception("Error invoking hub method: " + completionMessage.Error); + } + } + + else throw new NotImplementedException(); + + return default; + } + } + + private static DefaultHubDispatcher CreateDispatcher() where THub : Hub + { + var serviceCollection = new ServiceCollection(); + serviceCollection.TryAddScoped(typeof(IHubActivator<>), typeof(DefaultHubActivator<>)); + var provider = serviceCollection.BuildServiceProvider(); + var serviceScopeFactory = provider.GetService(); + + return new DefaultHubDispatcher( + serviceScopeFactory, + new HubContext(new DefaultHubLifetimeManager(NullLogger>.Instance)), + Options.Create(new HubOptions()), + Options.Create(new HubOptions()), + new Logger>(NullLoggerFactory.Instance)); + } + + private static MockHubConnectionContext CreateConnectionContext() + { + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport); + var contextOptions = new HubConnectionContextOptions() { KeepAliveInterval = TimeSpan.Zero }; + + return new MockHubConnectionContext( + connection, + contextOptions, + NullLoggerFactory.Instance); + } + + /// + /// For . + /// + private interface ITestDerivedParameter + { + public string Value { get; } + } + + /// + /// For . + /// + private abstract class TestDerivedParameterBase + { + public TestDerivedParameterBase(string value) => Value = value; + public string Value { get; } + } + + /// + /// For . + /// + private class TestDerivedParameter : TestDerivedParameterBase, ITestDerivedParameter + { + public TestDerivedParameter(string value) : base(value) { } + } + + /// + /// For . + /// + private class TestDerivedParameterHub : Hub + { + public async IAsyncEnumerable TestSubclass(TestDerivedParameterBase param, [EnumeratorCancellation]CancellationToken token) + { + await Task.Yield(); + yield return param.Value; + } + + public async IAsyncEnumerable TestImplementation(ITestDerivedParameter param, [EnumeratorCancellation]CancellationToken token) + { + await Task.Yield(); + yield return param.Value; + } + } + + /// + /// Hub methods might be written by users in a way that accepts an interface or base class as a parameter + /// and deserialization could supply a derived class (e.g. Json.NET's TypeNameHandling = TypeNameHandling.All). + /// This test ensures implementation and subclass arguments are correctly bound for dispatch. + /// + [Theory] + [InlineData(nameof(TestDerivedParameterHub.TestImplementation))] + [InlineData(nameof(TestDerivedParameterHub.TestSubclass))] + public async Task DispatchesDerivedArguments(string methodName) + { + var message = new TestDerivedParameter("Yup"); + var connectionContext = CreateConnectionContext(); + var dispatcher = CreateDispatcher(); + + await dispatcher.DispatchMessageAsync(connectionContext, new StreamInvocationMessage("123", methodName, new[] { message })); + await (connectionContext as MockHubConnectionContext).ReceivedCompleted.Task; + + Assert.Single(connectionContext.Values, message.Value); + } + } +}