// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;
using Moq;
using Xunit;
using Microsoft.AspNetCore.SignalR.Tests.Common;

namespace Microsoft.AspNetCore.SignalR.Tests
{
    /// <summary>
    /// Tests that drive a hub end point through a <see cref="TestClient"/> connection and
    /// check hub lifetime, method invocation and message routing.
    /// </summary>
    public class HubEndpointTests
    {
        /// <summary>
        /// Runs a connection from start to finish and checks how many times the hub was disposed.
        /// </summary>
        [Fact]
        public async Task HubsAreDisposed()
        {
            var trackDispose = new TrackDispose();
            var serviceProvider = CreateServiceProvider(s => s.AddSingleton(trackDispose));
            var endPoint = serviceProvider.GetService<HubEndPoint<TestHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);

                // kill the connection
                client.Dispose();

                // Bounded wait, like the other tests, so a regression fails instead of hanging.
                await endPointTask.OrTimeout();

                // NOTE(review): two disposals are expected - presumably one hub instance for the
                // connect callback and one for the disconnect callback; confirm against HubEndPoint.
                Assert.Equal(2, trackDispose.DisposeCount);
            }
        }

        /// <summary>
        /// When the lifetime manager throws from OnConnectedAsync, the exception surfaces from the
        /// end point, OnDisconnectedAsync is still called once, and no hub is created or released.
        /// </summary>
        [Fact]
        public async Task LifetimeManagerOnDisconnectedAsyncCalledIfLifetimeManagerOnConnectedAsyncThrows()
        {
            var mockLifetimeManager = new Mock<HubLifetimeManager<Hub>>();
            mockLifetimeManager
                .Setup(m => m.OnConnectedAsync(It.IsAny<Connection>()))
                .Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed."));

            var mockHubActivator = new Mock<IHubActivator<Hub, IClientProxy>>();

            var serviceProvider = CreateServiceProvider(services =>
            {
                services.AddSingleton(mockLifetimeManager.Object);
                services.AddSingleton(mockHubActivator.Object);
            });

            var endPoint = serviceProvider.GetService<HubEndPoint<Hub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var exception = await Assert.ThrowsAsync<InvalidOperationException>(
                    async () => await endPoint.OnConnectedAsync(client.Connection));
                Assert.Equal("Lifetime manager OnConnectedAsync failed.", exception.Message);

                client.Dispose();

                mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
                mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);

                // No hubs should be created since the connection is terminated
                mockHubActivator.Verify(m => m.Create(), Times.Never);
                mockHubActivator.Verify(m => m.Release(It.IsAny<Hub>()), Times.Never);
            }
        }

        /// <summary>
        /// When the hub's OnConnectedAsync fails, the failure surfaces from the end point task and
        /// the lifetime manager still sees exactly one connect and one disconnect.
        /// </summary>
        [Fact]
        public async Task HubOnDisconnectedAsyncCalledIfHubOnConnectedAsyncThrows()
        {
            var mockLifetimeManager = new Mock<HubLifetimeManager<OnConnectedThrowsHub>>();
            var serviceProvider = CreateServiceProvider(services =>
            {
                services.AddSingleton(mockLifetimeManager.Object);
            });

            var endPoint = serviceProvider.GetService<HubEndPoint<OnConnectedThrowsHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);
                client.Dispose();

                var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
                Assert.Equal("Hub OnConnected failed.", exception.Message);

                mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
                mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
            }
        }

        /// <summary>
        /// When the hub's OnDisconnectedAsync fails, the failure surfaces from the end point task and
        /// the lifetime manager still sees exactly one connect and one disconnect.
        /// </summary>
        [Fact]
        public async Task LifetimeManagerOnDisconnectedAsyncCalledIfHubOnDisconnectedAsyncThrows()
        {
            var mockLifetimeManager = new Mock<HubLifetimeManager<OnDisconnectedThrowsHub>>();
            var serviceProvider = CreateServiceProvider(services =>
            {
                services.AddSingleton(mockLifetimeManager.Object);
            });

            var endPoint = serviceProvider.GetService<HubEndPoint<OnDisconnectedThrowsHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);
                client.Dispose();

                var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
                Assert.Equal("Hub OnDisconnected failed.", exception.Message);

                mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
                mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
            }
        }

        /// <summary>
        /// A hub method returning a task with a result delivers that result to the caller.
        /// </summary>
        [Fact]
        public async Task HubMethodCanReturnValueFromTask()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);

                var result = await client.Invoke(nameof(MethodHub.TaskValueMethod)).OrTimeout();

                // json serializer makes this a long
                Assert.Equal(42L, result.Result);

                // kill the connection
                client.Dispose();

                await endPointTask.OrTimeout();
            }
        }

        /// <summary>
        /// Invoking "echo" reaches <c>MethodHub.Echo</c> even though the casing differs.
        /// </summary>
        [Fact]
        public async Task HubMethodsAreCaseInsensitive()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);

                var result = await client.Invoke("echo", "hello").OrTimeout();

                Assert.Null(result.Error);
                Assert.Equal("hello", result.Result);

                // kill the connection
                client.Dispose();

                await endPointTask.OrTimeout();
            }
        }

        /// <summary>
        /// A synchronous hub method's return value is delivered to the caller.
        /// </summary>
        [Fact]
        public async Task HubMethodCanReturnValue()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);

                var result = await client.Invoke(nameof(MethodHub.ValueMethod)).OrTimeout();

                // json serializer makes this a long
                Assert.Equal(43L, result.Result);

                // kill the connection
                client.Dispose();

                await endPointTask.OrTimeout();
            }
        }

        /// <summary>
        /// A static method declared on the hub can be invoked by name.
        /// </summary>
        [Fact]
        public async Task HubMethodCanBeStatic()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);

                var result = await client.Invoke(nameof(MethodHub.StaticMethod)).OrTimeout();

                Assert.Equal("fromStatic", result.Result);

                // kill the connection
                client.Dispose();

                await endPointTask.OrTimeout();
            }
        }

        /// <summary>
        /// A void hub method completes with a null result.
        /// </summary>
        [Fact]
        public async Task HubMethodCanBeVoid()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);

                var result = await client.Invoke(nameof(MethodHub.VoidMethod)).OrTimeout();

                Assert.Null(result.Result);

                // kill the connection
                client.Dispose();

                await endPointTask.OrTimeout();
            }
        }

        /// <summary>
        /// Arguments of several different types are bound to a multi-parameter hub method.
        /// </summary>
        [Fact]
        public async Task HubMethodWithMultiParam()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);

                var result = await client.Invoke(nameof(MethodHub.ConcatString), (byte)32, 42, 'm', "string").OrTimeout();

                Assert.Equal("32, 42, m, string", result.Result);

                // kill the connection
                client.Dispose();

                await endPointTask.OrTimeout();
            }
        }

        /// <summary>
        /// A method declared on the base hub can be invoked through the derived hub.
        /// </summary>
        [Fact]
        public async Task CanCallInheritedHubMethodFromInheritingHub()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<InheritedHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);

                var result = await client.Invoke(nameof(InheritedHub.BaseMethod), "string").OrTimeout();

                Assert.Equal("string", result.Result);

                // kill the connection
                client.Dispose();

                await endPointTask.OrTimeout();
            }
        }

        /// <summary>
        /// Invoking a virtual method runs the derived hub's override (10 - 10 = 0), not the base body.
        /// </summary>
        [Fact]
        public async Task CanCallOverridenVirtualHubMethod()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<InheritedHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);

                var result = await client.Invoke(nameof(InheritedHub.VirtualMethod), 10).OrTimeout();

                Assert.Equal(0L, result.Result);

                // kill the connection
                client.Dispose();

                await endPointTask.OrTimeout();
            }
        }

        /// <summary>
        /// An override of a method from the Hub base class is not exposed as an invokable hub method.
        /// </summary>
        [Fact]
        public async Task CannotCallOverriddenBaseHubMethod()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);

                var result = await client.Invoke(nameof(MethodHub.OnDisconnectedAsync)).OrTimeout();

                Assert.Equal("Unknown hub method 'OnDisconnectedAsync'", result.Error);

                // kill the connection
                client.Dispose();

                await endPointTask.OrTimeout();
            }
        }

        /// <summary>
        /// Resolving an end point for a hub with overloaded methods throws
        /// <see cref="NotSupportedException"/> with a message naming the duplicated method.
        /// </summary>
        [Fact]
        public void HubsCannotHaveOverloadedMethods()
        {
            var serviceProvider = CreateServiceProvider();

            var ex = Assert.Throws<NotSupportedException>(
                () => serviceProvider.GetService<HubEndPoint<InvalidHub>>());

            Assert.Equal("Duplicate definitions of 'OverloadedMethod'. Overloading is not supported.", ex.Message);
        }

        /// <summary>
        /// A broadcast invoked by one client is received by every connected client.
        /// </summary>
        [Fact]
        public async Task BroadcastHubMethod_SendsToAllClients()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();

            using (var firstClient = new TestClient(serviceProvider))
            using (var secondClient = new TestClient(serviceProvider))
            {
                var firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection);
                var secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection);

                // Both clients must be connected before the broadcast is sent.
                await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();

                await firstClient.Invoke(nameof(MethodHub.BroadcastMethod), "test").OrTimeout();

                foreach (var result in await Task.WhenAll(
                    firstClient.Read(),
                    secondClient.Read()).OrTimeout())
                {
                    Assert.Equal("Broadcast", result.Method);
                    Assert.Equal(1, result.Arguments.Length);
                    Assert.Equal("test", result.Arguments[0]);
                }

                // kill the connections
                firstClient.Dispose();
                secondClient.Dispose();

                await Task.WhenAll(firstEndPointTask, secondEndPointTask).OrTimeout();
            }
        }

        /// <summary>
        /// A group send reaches a client only after that client has been added to the group.
        /// </summary>
        [Fact]
        public async Task HubsCanAddAndSendToGroup()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();

            using (var firstClient = new TestClient(serviceProvider))
            using (var secondClient = new TestClient(serviceProvider))
            {
                var firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection);
                var secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection);

                await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();

                var result = await firstClient.Invoke(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout();

                // check that 'firstClient' hasn't received the group send
                Assert.Null(result.Id);

                // check that 'secondClient' hasn't received the group send
                Assert.Null(await secondClient.TryRead().OrTimeout());

                result = await secondClient.Invoke(nameof(MethodHub.GroupAddMethod), "testGroup").OrTimeout();
                Assert.Null(result.Id);

                await firstClient.Invoke(nameof(MethodHub.GroupSendMethod), "testGroup", "test").OrTimeout();

                // check that 'secondClient' has received the group send
                var descriptor = await secondClient.Read().OrTimeout();
                Assert.Equal("Send", descriptor.Method);
                Assert.Equal(1, descriptor.Arguments.Length);
                Assert.Equal("test", descriptor.Arguments[0]);

                // kill the connections
                firstClient.Dispose();
                secondClient.Dispose();

                await Task.WhenAll(firstEndPointTask, secondEndPointTask).OrTimeout();
            }
        }

        /// <summary>
        /// Removing a connection from a group it never joined completes without error.
        /// </summary>
        [Fact]
        public async Task RemoveFromGroupWhenNotInGroupDoesNotFail()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();

            using (var client = new TestClient(serviceProvider))
            {
                var endPointTask = endPoint.OnConnectedAsync(client.Connection);

                await client.Invoke(nameof(MethodHub.GroupRemoveMethod), "testGroup").OrTimeout();

                // kill the connection
                client.Dispose();

                await endPointTask.OrTimeout();
            }
        }

        /// <summary>
        /// A message addressed to a user name is received by that user's client.
        /// </summary>
        [Fact]
        public async Task HubsCanSendToUser()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();

            using (var firstClient = new TestClient(serviceProvider))
            using (var secondClient = new TestClient(serviceProvider))
            {
                var firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection);
                var secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection);

                await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();

                await firstClient.Invoke(nameof(MethodHub.ClientSendMethod), secondClient.Connection.User.Identity.Name, "test").OrTimeout();

                // check that 'secondClient' has received the user send
                var result = await secondClient.Read().OrTimeout();
                Assert.Equal("Send", result.Method);
                Assert.Equal(1, result.Arguments.Length);
                Assert.Equal("test", result.Arguments[0]);

                // kill the connections
                firstClient.Dispose();
                secondClient.Dispose();

                await Task.WhenAll(firstEndPointTask, secondEndPointTask).OrTimeout();
            }
        }

        /// <summary>
        /// A message addressed to a connection id is received by that connection's client.
        /// </summary>
        [Fact]
        public async Task HubsCanSendToConnection()
        {
            var serviceProvider = CreateServiceProvider();
            var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();

            using (var firstClient = new TestClient(serviceProvider))
            using (var secondClient = new TestClient(serviceProvider))
            {
                var firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection);
                var secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection);

                await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();

                await firstClient.Invoke(nameof(MethodHub.ConnectionSendMethod), secondClient.Connection.ConnectionId, "test").OrTimeout();

                // check that 'secondClient' has received the connection send
                var result = await secondClient.Read().OrTimeout();
                Assert.Equal("Send", result.Method);
                Assert.Equal(1, result.Arguments.Length);
                Assert.Equal("test", result.Arguments[0]);

                // kill the connections
                firstClient.Dispose();
                secondClient.Dispose();

                await Task.WhenAll(firstEndPointTask, secondEndPointTask).OrTimeout();
            }
        }

        /// <summary>
        /// Builds the closed <c>HubEndPoint&lt;&gt;</c> type for the given hub type.
        /// </summary>
        /// <param name="hubType">The hub type used as the generic argument.</param>
        /// <returns>The constructed end point type.</returns>
        private static Type GetEndPointType(Type hubType)
        {
            var endPointType = typeof(HubEndPoint<>);
            return endPointType.MakeGenericType(hubType);
        }

        /// <summary>
        /// Closes an open generic type definition over the given hub type.
        /// </summary>
        /// <param name="genericType">The open generic type definition.</param>
        /// <param name="hubType">The hub type used as the generic argument.</param>
        /// <returns>The constructed generic type.</returns>
        private static Type GetGenericType(Type genericType, Type hubType)
        {
            return genericType.MakeGenericType(hubType);
        }

        /// <summary>
        /// Creates a service provider with options, logging and SignalR registered.
        /// </summary>
        /// <param name="addServices">
        /// Optional callback run after the default registrations so a test can add or replace services.
        /// </param>
        /// <returns>The built service provider.</returns>
        private IServiceProvider CreateServiceProvider(Action<ServiceCollection> addServices = null)
        {
            var services = new ServiceCollection();
            services.AddOptions()
                .AddLogging()
                .AddSignalR();

            addServices?.Invoke(services);

            return services.BuildServiceProvider();
        }

        /// <summary>
        /// Hub whose OnConnectedAsync returns a task faulted with an <see cref="InvalidOperationException"/>.
        /// </summary>
        public class OnConnectedThrowsHub : Hub
        {
            public override Task OnConnectedAsync()
            {
                var tcs = new TaskCompletionSource<object>();
                tcs.SetException(new InvalidOperationException("Hub OnConnected failed."));
                return tcs.Task;
            }
        }

        /// <summary>
        /// Hub whose OnDisconnectedAsync returns a task faulted with an <see cref="InvalidOperationException"/>.
        /// </summary>
        public class OnDisconnectedThrowsHub : Hub
        {
            public override Task OnDisconnectedAsync(Exception exception)
            {
                var tcs = new TaskCompletionSource<object>();
                tcs.SetException(new InvalidOperationException("Hub OnDisconnected failed."));
                return tcs.Task;
            }
        }

        /// <summary>
        /// Hub exposing the methods invoked by the tests above.
        /// </summary>
        private class MethodHub : Hub
        {
            public override Task OnConnectedAsync()
            {
                // Completes the "ConnectedTask" stored in the connection metadata.
                // NOTE(review): presumably this is what TestClient.Connected awaits - confirm in TestClient.
                Context.Connection.Metadata.Get<TaskCompletionSource<bool>>("ConnectedTask").SetResult(true);
                return base.OnConnectedAsync();
            }

            public Task GroupRemoveMethod(string groupName)
            {
                return Groups.RemoveAsync(groupName);
            }

            public Task ClientSendMethod(string userId, string message)
            {
                return Clients.User(userId).InvokeAsync("Send", message);
            }

            public Task ConnectionSendMethod(string connectionId, string message)
            {
                return Clients.Client(connectionId).InvokeAsync("Send", message);
            }

            public Task GroupAddMethod(string groupName)
            {
                return Groups.AddAsync(groupName);
            }

            public Task GroupSendMethod(string groupName, string message)
            {
                return Clients.Group(groupName).InvokeAsync("Send", message);
            }

            public Task BroadcastMethod(string message)
            {
                return Clients.All.InvokeAsync("Broadcast", message);
            }

            public Task<int> TaskValueMethod()
            {
                return Task.FromResult(42);
            }

            public int ValueMethod()
            {
                return 43;
            }

            public string Echo(string data)
            {
                return data;
            }

            public static string StaticMethod()
            {
                return "fromStatic";
            }

            public void VoidMethod()
            {
            }

            public string ConcatString(byte b, int i, char c, string s)
            {
                return $"{b}, {i}, {c}, {s}";
            }

            public override Task OnDisconnectedAsync(Exception e)
            {
                return TaskCache.CompletedTask;
            }
        }

        /// <summary>
        /// Hub that inherits <c>BaseMethod</c> and overrides <c>VirtualMethod</c>.
        /// </summary>
        private class InheritedHub : BaseHub
        {
            public override int VirtualMethod(int num)
            {
                return num - 10;
            }
        }

        /// <summary>
        /// Base hub providing one plain method and one virtual method.
        /// </summary>
        private class BaseHub : Hub
        {
            public string BaseMethod(string message)
            {
                return message;
            }

            public virtual int VirtualMethod(int num)
            {
                return num;
            }
        }

        /// <summary>
        /// Hub with two overloads of the same method name, used to check that overloading is rejected.
        /// </summary>
        private class InvalidHub : Hub
        {
            public void OverloadedMethod(int num)
            {
            }

            public void OverloadedMethod(string message)
            {
            }
        }

        /// <summary>
        /// Hub that increments a shared <see cref="TrackDispose"/> counter each time it is disposed.
        /// </summary>
        private class TestHub : Hub
        {
            private readonly TrackDispose _trackDispose;

            public TestHub(TrackDispose trackDispose)
            {
                _trackDispose = trackDispose;
            }

            protected override void Dispose(bool dispose)
            {
                if (dispose)
                {
                    _trackDispose.DisposeCount++;
                }
            }
        }

        /// <summary>
        /// Counter shared between a test and the hubs it creates.
        /// </summary>
        private class TrackDispose
        {
            public int DisposeCount = 0;
        }
    }
}