diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 1a21ad47f1..f1975e12bc 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -261,7 +261,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++) { if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer && - hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType() == descriptor.OriginalParameterTypes[parameterPointer]) + descriptor.OriginalParameterTypes[parameterPointer].IsAssignableFrom(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType())) { // The types match so it isn't a synthetic argument, just copy it into the arguments array arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer]; diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index 504c2cc751..89c7f77498 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -3,12 +3,14 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization; +using Newtonsoft.Json.Serialization; namespace Microsoft.AspNetCore.SignalR.Tests { @@ -666,6 +668,30 @@ namespace Microsoft.AspNetCore.SignalR.Tests return output.Reader; } + public async IAsyncEnumerable DerivedParameterInterfaceAsyncEnumerable(IDerivedParameterTestObject param) + { + await Task.Yield(); + yield return param.Value; + } + + public async IAsyncEnumerable DerivedParameterBaseClassAsyncEnumerable(DerivedParameterTestObjectBase param) + { + await Task.Yield(); + yield return param.Value; + } + + public async IAsyncEnumerable DerivedParameterInterfaceAsyncEnumerableWithCancellation(IDerivedParameterTestObject param, [EnumeratorCancellation] CancellationToken token) + { + await Task.Yield(); + yield return param.Value; + } + + public async IAsyncEnumerable DerivedParameterBaseClassAsyncEnumerableWithCancellation(DerivedParameterTestObjectBase param, [EnumeratorCancellation] CancellationToken token) + { + await Task.Yield(); + yield return param.Value; + } + public class AsyncEnumerableImpl : IAsyncEnumerable { private readonly IAsyncEnumerable _inner; @@ -758,6 +784,37 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } } + + public interface IDerivedParameterTestObject + { + public string Value { get; set; } + } + + public abstract class DerivedParameterTestObjectBase : IDerivedParameterTestObject + { + public string Value { get; set; } + } + + public class DerivedParameterTestObject : DerivedParameterTestObjectBase { } + + public class DerivedParameterKnownTypesBinder : ISerializationBinder + { + private static readonly IEnumerable _knownTypes = new List() + { + typeof(DerivedParameterTestObject) + }; + + public static ISerializationBinder Instance { get; } = new DerivedParameterKnownTypesBinder(); + + public void BindToName(Type serializedType, out string assemblyName, out string typeName) + { + assemblyName = null; + typeName = serializedType.Name; + } + + public Type BindToType(string assemblyName, string typeName) => + _knownTypes.Single(type => type.Name == typeName); + } } public class SimpleHub : Hub diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 1c047973a3..c61a1b1c00 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -3641,6 +3641,63 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + /// + /// Hub methods might be written by users in a way that accepts an interface or base class as a parameter + /// and deserialization could supply a derived class. + /// This test ensures implementation and subclass arguments are correctly bound for dispatch. + /// + [Theory] + [InlineData(nameof(StreamingHub.DerivedParameterInterfaceAsyncEnumerable))] + [InlineData(nameof(StreamingHub.DerivedParameterBaseClassAsyncEnumerable))] + [InlineData(nameof(StreamingHub.DerivedParameterInterfaceAsyncEnumerableWithCancellation))] + [InlineData(nameof(StreamingHub.DerivedParameterBaseClassAsyncEnumerableWithCancellation))] + public async Task CanPassDerivedParameterToStreamHubMethod(string method) + { + using (StartVerifiableLog()) + { + var argument = new StreamingHub.DerivedParameterTestObject { Value = "test" }; + var protocolOptions = new NewtonsoftJsonHubProtocolOptions + { + PayloadSerializerSettings = new JsonSerializerSettings() + { + // The usage of TypeNameHandling.All is a security risk. + // If you're implementing this in your own application instead use your own 'type' field and a custom JsonConverter + // or ensure you're restricting to only known types with a custom SerializationBinder like we are here. + // See https://github.com/aspnet/AspNetCore/issues/11495#issuecomment-505047422 + TypeNameHandling = TypeNameHandling.All, + SerializationBinder = StreamingHub.DerivedParameterKnownTypesBinder.Instance + } + }; + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider( + services => services.AddSignalR() + .AddNewtonsoftJsonProtocol(o => o.PayloadSerializerSettings = protocolOptions.PayloadSerializerSettings), + LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + var invocationBinder = new Mock(); + invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny())).Returns(typeof(string)); + + using (var client = new TestClient( + protocol: new NewtonsoftJsonHubProtocol(Options.Create(protocolOptions)), + invocationBinder: invocationBinder.Object)) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + // Wait for a connection, or for the endpoint to fail. + await client.Connected.OrThrowIfOtherFails(connectionHandlerTask).OrTimeout(); + + var messages = await client.StreamAsync(method, argument).OrTimeout(); + + Assert.Equal(2, messages.Count); + HubConnectionHandlerTestUtils.AssertHubMessage(new StreamItemMessage(string.Empty, argument.Value), messages[0]); + HubConnectionHandlerTestUtils.AssertHubMessage(CompletionMessage.Empty(string.Empty), messages[1]); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + private class CustomHubActivator : IHubActivator where THub : Hub { public int ReleaseCount;