diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubClients`T.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubClients`T.cs index 69926e9695..2b2aca344f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubClients`T.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubClients`T.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/TypedClientBuilder.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/TypedClientBuilder.cs index bcae8fec8c..7f416b9ea6 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/TypedClientBuilder.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/TypedClientBuilder.cs @@ -1,9 +1,8 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Reflection; using System.Reflection.Emit; @@ -26,7 +25,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal public static void Validate() { // The following will throw if T is not a valid type - _ = _builder.Value; + _ = _builder.Value; } private static Func GenerateClientBuilder() @@ -151,7 +150,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal for (var i = 0; i < paramTypes.Length; i++) { generator.Emit(OpCodes.Ldloc_0); // Object array loaded - generator.Emit(OpCodes.Ldc_I4, i); + generator.Emit(OpCodes.Ldc_I4, i); generator.Emit(OpCodes.Ldarg, i + 1); // i + 1 generator.Emit(OpCodes.Box, paramTypes[i]); generator.Emit(OpCodes.Stelem_Ref); @@ -161,13 +160,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal generator.Emit(OpCodes.Ldloc_0); generator.Emit(OpCodes.Callvirt, invokeMethod); - if (interfaceMethodInfo.ReturnType == typeof(void)) - { - // void return - generator.Emit(OpCodes.Pop); - } - - generator.Emit(OpCodes.Ret); // Return + generator.Emit(OpCodes.Ret); // Return the Task returned by 'invokeMethod' } private static void VerifyInterface(Type interfaceType) @@ -184,7 +177,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal if (interfaceType.GetEvents().Length != 0) { - throw new InvalidOperationException("Type can not contain events."); + throw new InvalidOperationException("Type must not contain events."); } foreach (var method in interfaceType.GetMethods()) @@ -200,28 +193,26 @@ namespace Microsoft.AspNetCore.SignalR.Internal private static void VerifyMethod(Type interfaceType, MethodInfo interfaceMethod) { - if (interfaceMethod.ReturnType != typeof(void) && interfaceMethod.ReturnType != typeof(Task)) + if (interfaceMethod.ReturnType != typeof(Task)) { - throw new InvalidOperationException("Method must return Void or Task."); + throw new InvalidOperationException( + $"Cannot generate proxy implementation for '{typeof(T).FullName}.{interfaceMethod.Name}'. All client proxy methods must return '{typeof(Task).FullName}'."); } foreach (var parameter in interfaceMethod.GetParameters()) { - VerifyParameter(interfaceType, interfaceMethod, parameter); - } - } + if (parameter.IsOut) + { + throw new InvalidOperationException( + $"Cannot generate proxy implementation for '{typeof(T).FullName}.{interfaceMethod.Name}'. Client proxy methods must not have 'out' parameters."); + } - private static void VerifyParameter(Type interfaceType, MethodInfo interfaceMethod, ParameterInfo parameter) - { - if (parameter.IsOut) - { - throw new InvalidOperationException("Method must not take out parameters."); - } - - if (parameter.ParameterType.IsByRef) - { - throw new InvalidOperationException("Method must not take reference parameters."); + if (parameter.ParameterType.IsByRef) + { + throw new InvalidOperationException( + $"Cannot generate proxy implementation for '{typeof(T).FullName}.{interfaceMethod.Name}'. Client proxy methods must not have 'ref' parameters."); + } } } } -} \ No newline at end of file +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs index 2ac5a70512..9f6189d785 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs @@ -484,6 +484,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + public class SimpleVoidReturningTypedHub : Hub + { + public override Task OnConnectedAsync() + { + // Derefernce Clients, to force initialization of the TypedHubClient + Clients.All.Send("herp"); + return Task.CompletedTask; + } + } + public class SimpleTypedHub : Hub { public override async Task OnConnectedAsync() @@ -534,6 +544,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests Task Send(string message); } + public interface IVoidReturningTypedHubClient + { + void Send(string message); + } + public class ConnectionLifetimeHub : Hub { private readonly ConnectionLifetimeState _state; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs index e3068ffd0e..ec8119b19c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs @@ -143,6 +143,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests await context.Clients.All.Send("test"); } + [Fact] + public void FailsToLoadInvalidTypedHubClient() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(); + var ex = Assert.Throws(() => + serviceProvider.GetRequiredService>()); + Assert.Equal($"Cannot generate proxy implementation for '{typeof(IVoidReturningTypedHubClient).FullName}.{nameof(IVoidReturningTypedHubClient.Send)}'. All client proxy methods must return '{typeof(Task).FullName}'.", ex.Message); + } + [Fact] public async Task HandshakeFailureFromUnknownProtocolSendsResponseWithError() { @@ -958,6 +967,22 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task FailsToInitializeInvalidTypedHub() + { + var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(SimpleVoidReturningTypedHub)); + + using (var firstClient = new TestClient()) + { + // ConnectAsync returns a Task and it's the INNER Task that will be faulted. + var connectionTask = await firstClient.ConnectAsync(connectionHandler); + + // We should get a close frame now + var close = Assert.IsType(await firstClient.ReadAsync()); + Assert.Equal("Connection closed with an error.", close.Error); + } + } + [Theory] [MemberData(nameof(HubTypes))] public async Task SendToAllExcept(Type hubType) diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/Internal/TypedClientBuilderTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/TypedClientBuilderTests.cs new file mode 100644 index 0000000000..2396ffcb12 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/TypedClientBuilderTests.cs @@ -0,0 +1,223 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Tests.Internal +{ + public class TypedClientBuilderTests + { + [Fact] + public async Task ProducesImplementationThatProxiesMethodsToIClientProxyAsync() + { + var clientProxy = new MockProxy(); + var typedProxy = TypedClientBuilder.Build(clientProxy); + + var objArg = new object(); + var task = typedProxy.Method("foo", 42, objArg); + Assert.False(task.IsCompleted); + + Assert.Collection(clientProxy.Sends, + send => + { + Assert.Equal("Method", send.Method); + Assert.Equal("foo", send.Arguments[0]); + Assert.Equal(42, send.Arguments[1]); + Assert.Same(objArg, send.Arguments[2]); + send.Complete(); + }); + + await task.OrTimeout(); + } + + [Fact] + public async Task SupportsSubInterfaces() + { + var clientProxy = new MockProxy(); + var typedProxy = TypedClientBuilder.Build(clientProxy); + + var objArg = new object(); + var task1 = typedProxy.Method("foo", 42, objArg); + Assert.False(task1.IsCompleted); + + var task2 = typedProxy.SubMethod("bar"); + Assert.False(task2.IsCompleted); + + Assert.Collection(clientProxy.Sends, + send1 => + { + Assert.Equal("Method", send1.Method); + Assert.Collection(send1.Arguments, + arg1 => Assert.Equal("foo", arg1), + arg2 => Assert.Equal(42, arg2), + arg3 => Assert.Same(objArg, arg3)); + send1.Complete(); + }, + send2 => + { + Assert.Equal("SubMethod", send2.Method); + Assert.Collection(send2.Arguments, + arg1 => Assert.Equal("bar", arg1)); + send2.Complete(); + }); + + await task1.OrTimeout(); + await task2.OrTimeout(); + } + + [Fact] + public void ThrowsIfProvidedAClass() + { + var clientProxy = new MockProxy(); + var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); + Assert.Equal("Type must be an interface.", ex.Message); + } + + [Fact] + public void ThrowsIfProvidedAStruct() + { + var clientProxy = new MockProxy(); + var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); + Assert.Equal("Type must be an interface.", ex.Message); + } + + [Fact] + public void ThrowsIfProvidedADelegate() + { + var clientProxy = new MockProxy(); + var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); + Assert.Equal("Type must be an interface.", ex.Message); + } + + [Fact] + public void ThrowsIfInterfaceHasVoidReturningMethod() + { + var clientProxy = new MockProxy(); + var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); + Assert.Equal($"Cannot generate proxy implementation for '{typeof(IVoidMethodClient).FullName}.{nameof(IVoidMethodClient.Method)}'. All client proxy methods must return '{typeof(Task).FullName}'.", ex.Message); + } + + [Fact] + public void ThrowsIfInterfaceHasNonTaskReturns() + { + var clientProxy = new MockProxy(); + var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); + Assert.Equal($"Cannot generate proxy implementation for '{typeof(IStringMethodClient).FullName}.{nameof(IStringMethodClient.Method)}'. All client proxy methods must return '{typeof(Task).FullName}'.", ex.Message); + } + + [Fact] + public void ThrowsIfInterfaceMethodHasOutParam() + { + var clientProxy = new MockProxy(); + var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); + Assert.Equal( + $"Cannot generate proxy implementation for '{typeof(IOutParamMethodClient).FullName}.{nameof(IOutParamMethodClient.Method)}'. Client proxy methods must not have 'out' parameters.", ex.Message); + } + + [Fact] + public void ThrowsIfInterfaceMethodHasRefParam() + { + var clientProxy = new MockProxy(); + var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); + Assert.Equal( + $"Cannot generate proxy implementation for '{typeof(IRefParamMethodClient).FullName}.{nameof(IRefParamMethodClient.Method)}'. Client proxy methods must not have 'ref' parameters.", ex.Message); + } + + [Fact] + public void ThrowsIfInterfaceHasProperties() + { + var clientProxy = new MockProxy(); + var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); + Assert.Equal("Type must not contain properties.", ex.Message); + } + + [Fact] + public void ThrowsIfInterfaceHasEvents() + { + var clientProxy = new MockProxy(); + var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); + Assert.Equal("Type must not contain events.", ex.Message); + } + + public interface ITestClient + { + Task Method(string arg1, int arg2, object arg3); + } + + public interface IVoidMethodClient + { + void Method(string arg1, int arg2, object arg3); + } + + public interface IStringMethodClient + { + string Method(string arg1, int arg2, object arg3); + } + + public interface IOutParamMethodClient + { + Task Method(out string arg1); + } + + public interface IRefParamMethodClient + { + Task Method(ref string arg1); + } + + public interface IInheritedClient : ITestClient + { + Task SubMethod(string foo); + } + + public interface IPropertiesClient + { + string Property { get; } + } + + public interface IEventsClient + { + event EventHandler Event; + } + + private class MockProxy : IClientProxy + { + public IList Sends { get; } = new List(); + + public Task SendCoreAsync(string method, object[] args) + { + var tcs = new TaskCompletionSource(); + + Sends.Add(new SendContext(method, args, tcs)); + + return tcs.Task; + } + } + + private struct SendContext + { + private TaskCompletionSource _tcs; + + public string Method { get; } + public object[] Arguments { get; } + + public SendContext(string method, object[] arguments, TaskCompletionSource tcs) : this() + { + Method = method; + Arguments = arguments; + _tcs = tcs; + } + + public void Complete() + { + _tcs.TrySetResult(null); + } + } + } +}