diff --git a/src/Microsoft.AspNetCore.SignalR.Common/JsonNetInvocationAdapter.cs b/src/Microsoft.AspNetCore.SignalR.Common/JsonNetInvocationAdapter.cs index e150ee6c5e..8c7a1a052d 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/JsonNetInvocationAdapter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/JsonNetInvocationAdapter.cs @@ -1,9 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; using System.IO; -using System.Reflection; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Internal; diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs index 38ddfbe662..50cd69fcb0 100644 --- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs @@ -36,6 +36,11 @@ namespace Microsoft.AspNetCore.SignalR { var groups = connection.Metadata.Get>("groups"); + if (groups == null) + { + return TaskCache.CompletedTask; + } + lock (groups) { groups.Remove(groupName); @@ -102,7 +107,7 @@ namespace Microsoft.AspNetCore.SignalR { return InvokeAllWhere(methodName, args, connection => { - return connection.User.Identity.Name == userId; + return string.Equals(connection.User.Identity.Name, userId, StringComparison.Ordinal); }); } diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index bfa74fc2d0..b0c184bfba 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -262,10 +262,22 @@ namespace Microsoft.AspNetCore.SignalR }; } - private static bool IsHubMethod(MethodInfo m) + private static bool IsHubMethod(MethodInfo methodInfo) { // TODO: Add more checks - return m.IsPublic && !m.IsSpecialName; + if (!methodInfo.IsPublic || methodInfo.IsSpecialName) + { + return false; + } + + var baseDefinition = methodInfo.GetBaseDefinition().DeclaringType; + var baseType = baseDefinition.GetTypeInfo().IsGenericType ? baseDefinition.GetGenericTypeDefinition() : baseDefinition; + if (typeof(Hub<>) == baseType) + { + return false; + } + + return true; } Type IInvocationBinder.GetReturnType(string invocationId) diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 1cf03fdc47..e7eca16802 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -2,14 +2,17 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO; using System.IO.Pipelines; using System.Security.Claims; -using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Internal; using Moq; +using Newtonsoft.Json; using Xunit; namespace Microsoft.AspNetCore.SignalR.Tests @@ -58,17 +61,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests await connectionWrapper.ApplicationStartedReading; - var buffer = connectionWrapper.ConnectionState.Application.Output.Alloc(); - buffer.Write(Encoding.UTF8.GetBytes("0xdeadbeef")); - await buffer.FlushAsync(); + var invocationAdapter = serviceProvider.GetService(); + var adapter = invocationAdapter.GetInvocationAdapter("json"); + await SendRequest(connectionWrapper.Connection.Transport, adapter, "0xdeadbeef"); connectionWrapper.Dispose(); - // InvalidCastException because the payload is not a JObject - // which is expected by the formatter - await Assert.ThrowsAsync(async () => await endPointTask); + await Assert.ThrowsAsync(async () => await endPointTask); - Mock.Get(hub).Verify(h => h.OnDisconnectedAsync(It.IsNotNull()), Times.Once()); + Mock.Get(hub).Verify(h => h.OnDisconnectedAsync(It.IsNotNull()), Times.Once()); } } @@ -185,6 +186,407 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task HubMethodCanReturnValueFromTask() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var connectionWrapper = new ConnectionWrapper()) + { + var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); + + await connectionWrapper.ApplicationStartedReading; + + var invocationAdapter = serviceProvider.GetService(); + var adapter = invocationAdapter.GetInvocationAdapter("json"); + + await SendRequest(connectionWrapper.Connection.Transport, adapter, "TaskValueMethod"); + var res = await ReadConnectionOutputAsync(connectionWrapper.Connection.Transport); + // json serializer makes this a long + Assert.Equal(42L, res.Result); + + // kill the connection + connectionWrapper.Connection.Dispose(); + + await endPointTask; + } + } + + [Fact] + public async Task HubMethodCanReturnValue() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var connectionWrapper = new ConnectionWrapper()) + { + var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); + + await connectionWrapper.ApplicationStartedReading; + + var invocationAdapter = serviceProvider.GetService(); + var adapter = invocationAdapter.GetInvocationAdapter("json"); + + await SendRequest(connectionWrapper.Connection.Transport, adapter, "ValueMethod"); + var res = await ReadConnectionOutputAsync(connectionWrapper.Connection.Transport); + // json serializer makes this a long + Assert.Equal(43L, res.Result); + + // kill the connection + connectionWrapper.Connection.Dispose(); + + await endPointTask; + } + } + + [Fact] + public async Task HubMethodCanBeStatic() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var connectionWrapper = new ConnectionWrapper()) + { + var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); + + await connectionWrapper.ApplicationStartedReading; + + var invocationAdapter = serviceProvider.GetService(); + var adapter = invocationAdapter.GetInvocationAdapter("json"); + + await SendRequest(connectionWrapper.Connection.Transport, adapter, "StaticMethod"); + var res = await ReadConnectionOutputAsync(connectionWrapper.Connection.Transport); + Assert.Equal("fromStatic", res.Result); + + // kill the connection + connectionWrapper.Connection.Dispose(); + + await endPointTask; + } + } + + [Fact] + public async Task HubMethodCanBeVoid() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var connectionWrapper = new ConnectionWrapper()) + { + var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); + + await connectionWrapper.ApplicationStartedReading; + + var invocationAdapter = serviceProvider.GetService(); + var adapter = invocationAdapter.GetInvocationAdapter("json"); + + await SendRequest(connectionWrapper.Connection.Transport, adapter, "VoidMethod"); + var res = await ReadConnectionOutputAsync(connectionWrapper.Connection.Transport); + Assert.Equal(null, res.Result); + + // kill the connection + connectionWrapper.Connection.Dispose(); + + await endPointTask; + } + } + + [Fact] + public async Task HubMethodWithMultiParam() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var connectionWrapper = new ConnectionWrapper()) + { + var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); + + await connectionWrapper.ApplicationStartedReading; + + var invocationAdapter = serviceProvider.GetService(); + var adapter = invocationAdapter.GetInvocationAdapter("json"); + + await SendRequest(connectionWrapper.Connection.Transport, adapter, "ConcatString", (byte)32, 42, 'm', "string"); + var res = await ReadConnectionOutputAsync(connectionWrapper.Connection.Transport); + Assert.Equal("32, 42, m, string", res.Result); + + // kill the connection + connectionWrapper.Connection.Dispose(); + + await endPointTask; + } + } + + [Fact] + public async Task CannotCallOverriddenBaseHubMethod() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var connectionWrapper = new ConnectionWrapper()) + { + var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); + + await connectionWrapper.ApplicationStartedReading; + + var invocationAdapter = serviceProvider.GetService(); + var adapter = invocationAdapter.GetInvocationAdapter("json"); + + await SendRequest(connectionWrapper.Connection.Transport, adapter, "OnDisconnectedAsync"); + + try + { + await endPointTask; + Assert.True(false); + } + catch (InvalidOperationException ex) + { + Assert.Equal("The hub method 'OnDisconnectedAsync' could not be resolved.", ex.Message); + } + } + } + + [Fact] + public async Task BroadcastHubMethod_SendsToAllClients() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var firstConnection = new ConnectionWrapper()) + using (var secondConnection = new ConnectionWrapper()) + { + var firstEndPointTask = endPoint.OnConnectedAsync(firstConnection.Connection); + var secondEndPointTask = endPoint.OnConnectedAsync(secondConnection.Connection); + + await Task.WhenAll(firstConnection.ApplicationStartedReading, secondConnection.ApplicationStartedReading); + + var invocationAdapter = serviceProvider.GetService(); + var adapter = invocationAdapter.GetInvocationAdapter("json"); + + await SendRequest(firstConnection.Connection.Transport, adapter, "BroadcastMethod", "test"); + + foreach (var res in await Task.WhenAll( + ReadConnectionOutputAsync(firstConnection.Connection.Transport), + ReadConnectionOutputAsync(secondConnection.Connection.Transport))) + { + Assert.Equal("Broadcast", res.Method); + Assert.Equal(1, res.Arguments.Length); + Assert.Equal("test", res.Arguments[0]); + } + + // kill the connections + firstConnection.Connection.Dispose(); + secondConnection.Connection.Dispose(); + + await Task.WhenAll(firstEndPointTask, secondEndPointTask); + } + } + + [Fact] + public async Task HubsCanAddAndSendToGroup() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var firstConnection = new ConnectionWrapper()) + using (var secondConnection = new ConnectionWrapper()) + { + var firstEndPointTask = endPoint.OnConnectedAsync(firstConnection.Connection); + var secondEndPointTask = endPoint.OnConnectedAsync(secondConnection.Connection); + + await Task.WhenAll(firstConnection.ApplicationStartedReading, secondConnection.ApplicationStartedReading); + + var invocationAdapter = serviceProvider.GetService(); + var adapter = invocationAdapter.GetInvocationAdapter("json"); + + await SendRequest_IgnoreReceive(firstConnection.Connection.Transport, adapter, "GroupSendMethod", "testGroup", "test"); + // check that 'secondConnection' hasn't received the group send + Assert.False(((PipelineReaderWriter)secondConnection.Connection.Transport.Output).ReadAsync().IsCompleted); + + await SendRequest_IgnoreReceive(secondConnection.Connection.Transport, adapter, "GroupAddMethod", "testGroup"); + + await SendRequest(firstConnection.Connection.Transport, adapter, "GroupSendMethod", "testGroup", "test"); + // check that 'firstConnection' hasn't received the group send + Assert.False(((PipelineReaderWriter)firstConnection.Connection.Transport.Output).ReadAsync().IsCompleted); + + // check that 'secondConnection' has received the group send + var res = await ReadConnectionOutputAsync(secondConnection.Connection.Transport); + Assert.Equal("Send", res.Method); + Assert.Equal(1, res.Arguments.Length); + Assert.Equal("test", res.Arguments[0]); + + // kill the connections + firstConnection.Connection.Dispose(); + secondConnection.Connection.Dispose(); + + await Task.WhenAll(firstEndPointTask, secondEndPointTask); + } + } + + [Fact] + public async Task RemoveFromGroupWhenNotInGroupDoesNotFail() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var connection = new ConnectionWrapper()) + { + var endPointTask = endPoint.OnConnectedAsync(connection.Connection); + + await connection.ApplicationStartedReading; + + var invocationAdapter = serviceProvider.GetService(); + var writer = invocationAdapter.GetInvocationAdapter("json"); + + await SendRequest_IgnoreReceive(connection.Connection.Transport, writer, "GroupRemoveMethod", "testGroup"); + + // kill the connection + connection.Connection.Dispose(); + + await endPointTask; + } + } + + [Fact] + public async Task HubsCanSendToUser() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var firstConnection = new ConnectionWrapper()) + using (var secondConnection = new ConnectionWrapper()) + { + var firstEndPointTask = endPoint.OnConnectedAsync(firstConnection.Connection); + var secondEndPointTask = endPoint.OnConnectedAsync(secondConnection.Connection); + + await Task.WhenAll(firstConnection.ApplicationStartedReading, secondConnection.ApplicationStartedReading); + + var invocationAdapter = serviceProvider.GetService(); + var adapter = invocationAdapter.GetInvocationAdapter("json"); + + await SendRequest_IgnoreReceive(firstConnection.Connection.Transport, adapter, "ClientSendMethod", secondConnection.Connection.User.Identity.Name, "test"); + + // check that 'secondConnection' has received the group send + var res = await ReadConnectionOutputAsync(secondConnection.Connection.Transport); + Assert.Equal("Send", res.Method); + Assert.Equal(1, res.Arguments.Length); + Assert.Equal("test", res.Arguments[0]); + + // kill the connections + firstConnection.Connection.Dispose(); + secondConnection.Connection.Dispose(); + + await Task.WhenAll(firstEndPointTask, secondEndPointTask); + } + } + + [Fact] + public async Task HubsCanSendToConnection() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var firstConnection = new ConnectionWrapper()) + using (var secondConnection = new ConnectionWrapper()) + { + var firstEndPointTask = endPoint.OnConnectedAsync(firstConnection.Connection); + var secondEndPointTask = endPoint.OnConnectedAsync(secondConnection.Connection); + + await Task.WhenAll(firstConnection.ApplicationStartedReading, secondConnection.ApplicationStartedReading); + + var invocationAdapter = serviceProvider.GetService(); + var adapter = invocationAdapter.GetInvocationAdapter("json"); + + await SendRequest_IgnoreReceive(firstConnection.Connection.Transport, adapter, "ConnectionSendMethod", secondConnection.Connection.ConnectionId, "test"); + + // check that 'secondConnection' has received the group send + var res = await ReadConnectionOutputAsync(secondConnection.Connection.Transport); + Assert.Equal("Send", res.Method); + Assert.Equal(1, res.Arguments.Length); + Assert.Equal("test", res.Arguments[0]); + + // kill the connections + firstConnection.Connection.Dispose(); + secondConnection.Connection.Dispose(); + + await Task.WhenAll(firstEndPointTask, secondEndPointTask); + } + } + + private class MethodHub : Hub + { + public Task GroupRemoveMethod(string groupName) + { + return Groups.RemoveAsync(groupName); + } + + public Task ClientSendMethod(string userId, string message) + { + return Clients.User(userId).InvokeAsync("Send", message); + } + + public Task ConnectionSendMethod(string connectionId, string message) + { + return Clients.Client(connectionId).InvokeAsync("Send", message); + } + + public Task GroupAddMethod(string groupName) + { + return Groups.AddAsync(groupName); + } + + public Task GroupSendMethod(string groupName, string message) + { + return Clients.Group(groupName).InvokeAsync("Send", message); + } + + public Task BroadcastMethod(string message) + { + return Clients.All.InvokeAsync("Broadcast", message); + } + + public Task TaskValueMethod() + { + return Task.FromResult(42); + } + + public int ValueMethod() + { + return 43; + } + + static public string StaticMethod() + { + return "fromStatic"; + } + + public void VoidMethod() + { + } + + public string ConcatString(byte b, int i, char c, string s) + { + return $"{b}, {i}, {c}, {s}"; + } + + public override Task OnDisconnectedAsync(Exception e) + { + return TaskCache.CompletedTask; + } + } + private class TestHub : Hub { private TrackDispose _trackDispose; @@ -208,6 +610,44 @@ namespace Microsoft.AspNetCore.SignalR.Tests public int DisposeCount = 0; } + public async Task SendRequest(IPipelineConnection connection, IInvocationAdapter writer, string method, params object[] args) + { + if (connection == null) + { + throw new ArgumentNullException(); + } + + var stream = new MemoryStream(); + await writer.WriteMessageAsync(new InvocationDescriptor + { + Arguments = args, + Method = method + }, stream); + + var buffer = ((PipelineReaderWriter)connection.Input).Alloc(); + buffer.Write(stream.ToArray()); + await buffer.FlushAsync(); + } + + public async Task SendRequest_IgnoreReceive(IPipelineConnection connection, IInvocationAdapter writer, string method, params object[] args) + { + await SendRequest(connection, writer, method, args); + + var methodResult = await ((PipelineReaderWriter)connection.Output).ReadAsync(); + ((PipelineReaderWriter)connection.Output).AdvanceReader(methodResult.Buffer.End, methodResult.Buffer.End); + } + + private async Task ReadConnectionOutputAsync(IPipelineConnection connection) + { + // TODO: other formats? + var methodResult = await ((PipelineReaderWriter)connection.Output).ReadAsync(); + var serializer = new JsonSerializer(); + var res = serializer.Deserialize(new JsonTextReader(new StreamReader(new MemoryStream(methodResult.Buffer.ToArray())))); + ((PipelineReaderWriter)connection.Output).AdvanceReader(methodResult.Buffer.End, methodResult.Buffer.End); + + return res; + } + private IServiceProvider CreateServiceProvider(Action addServices = null) { var services = new ServiceCollection(); @@ -222,6 +662,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests private class ConnectionWrapper : IDisposable { + private static int ID; private PipelineFactory _factory; public StreamingConnectionState ConnectionState; @@ -239,7 +680,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests ConnectionState = (StreamingConnectionState)connectionManager.CreateConnection(ConnectionMode.Streaming); ConnectionState.Connection.Metadata["formatType"] = format; - ConnectionState.Connection.User = new ClaimsPrincipal(new ClaimsIdentity()); + ConnectionState.Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref ID).ToString()) })); } public void Dispose() diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/project.json b/test/Microsoft.AspNetCore.SignalR.Tests/project.json index d868797cff..d7c4dcf30c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/project.json +++ b/test/Microsoft.AspNetCore.SignalR.Tests/project.json @@ -10,6 +10,10 @@ "Microsoft.AspNetCore.SignalR": "1.0.0-*", "Microsoft.Extensions.DependencyInjection": "1.2.0-*", "Microsoft.Extensions.Logging": "1.2.0-*", + "Microsoft.Extensions.TaskCache.Sources": { + "version": "1.2.0-*", + "type": "build" + }, "Moq": "4.6.36-*", "xunit": "2.2.0-*" },