diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index a29bf360d5..78bc3089dd 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -137,7 +137,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests var adapter = invocationAdapter.GetInvocationAdapter("json"); await SendRequest(connectionWrapper, adapter, nameof(MethodHub.TaskValueMethod)); - var result = await ReadConnectionOutputAsync(connectionWrapper); + var result = await ReadConnectionOutputAsync(connectionWrapper).OrTimeout(); // json serializer makes this a long Assert.Equal(42L, result.Result); @@ -145,7 +145,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests // kill the connection connectionWrapper.Connection.Dispose(); - await endPointTask; + await endPointTask.OrTimeout(); } } @@ -164,7 +164,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests var adapter = invocationAdapter.GetInvocationAdapter("json"); await SendRequest(connectionWrapper, adapter, "ValueMethod"); - var result = await ReadConnectionOutputAsync(connectionWrapper); + var result = await ReadConnectionOutputAsync(connectionWrapper).OrTimeout(); // json serializer makes this a long Assert.Equal(43L, result.Result); @@ -172,7 +172,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests // kill the connection connectionWrapper.Connection.Dispose(); - await endPointTask; + await endPointTask.OrTimeout(); } } @@ -191,14 +191,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests var adapter = invocationAdapter.GetInvocationAdapter("json"); await SendRequest(connectionWrapper, adapter, "StaticMethod"); - var result = await ReadConnectionOutputAsync(connectionWrapper); + var result = await ReadConnectionOutputAsync(connectionWrapper).OrTimeout(); Assert.Equal("fromStatic", result.Result); // kill the connection connectionWrapper.Connection.Dispose(); - await endPointTask; + await endPointTask.OrTimeout(); } } @@ -217,14 +217,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests var adapter = invocationAdapter.GetInvocationAdapter("json"); await SendRequest(connectionWrapper, adapter, "VoidMethod"); - var result = await ReadConnectionOutputAsync(connectionWrapper); + var result = await ReadConnectionOutputAsync(connectionWrapper).OrTimeout(); Assert.Null(result.Result); // kill the connection connectionWrapper.Connection.Dispose(); - await endPointTask; + await endPointTask.OrTimeout(); } } @@ -243,7 +243,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests var adapter = invocationAdapter.GetInvocationAdapter("json"); await SendRequest(connectionWrapper, adapter, "ConcatString", (byte)32, 42, 'm', "string"); - var result = await ReadConnectionOutputAsync(connectionWrapper); + var result = await ReadConnectionOutputAsync(connectionWrapper).OrTimeout(); Assert.Equal("32, 42, m, string", result.Result); // kill the connection @@ -268,7 +268,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests var adapter = invocationAdapter.GetInvocationAdapter("json"); await SendRequest(connectionWrapper, adapter, "OnDisconnectedAsync"); - var result = await ReadConnectionOutputAsync(connectionWrapper); + var result = await ReadConnectionOutputAsync(connectionWrapper).OrTimeout(); Assert.Equal("Unknown hub method 'OnDisconnectedAsync'", result.Error); } @@ -287,6 +287,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests var firstEndPointTask = endPoint.OnConnectedAsync(firstConnection.Connection); var secondEndPointTask = endPoint.OnConnectedAsync(secondConnection.Connection); + await Task.WhenAll(firstConnection.Connected, secondConnection.Connected).OrTimeout(); + var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); @@ -294,7 +296,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests foreach (var result in await Task.WhenAll( ReadConnectionOutputAsync(firstConnection), - ReadConnectionOutputAsync(secondConnection))) + ReadConnectionOutputAsync(secondConnection)).OrTimeout()) { Assert.Equal("Broadcast", result.Method); Assert.Equal(1, result.Arguments.Length); @@ -305,7 +307,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests firstConnection.Connection.Dispose(); secondConnection.Connection.Dispose(); - await Task.WhenAll(firstEndPointTask, secondEndPointTask); + await Task.WhenAll(firstEndPointTask, secondEndPointTask).OrTimeout(); } } @@ -322,6 +324,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests var firstEndPointTask = endPoint.OnConnectedAsync(firstConnection.Connection); var secondEndPointTask = endPoint.OnConnectedAsync(secondConnection.Connection); + await Task.WhenAll(firstConnection.Connected, secondConnection.Connected).OrTimeout(); + var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); @@ -338,7 +342,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.False(firstConnection.Application.Input.TryRead(out message)); // check that 'secondConnection' has received the group send - var res = await ReadConnectionOutputAsync(secondConnection); + var res = await ReadConnectionOutputAsync(secondConnection).OrTimeout(); Assert.Equal("Send", res.Method); Assert.Equal(1, res.Arguments.Length); Assert.Equal("test", res.Arguments[0]); @@ -347,7 +351,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests firstConnection.Connection.Dispose(); secondConnection.Connection.Dispose(); - await Task.WhenAll(firstEndPointTask, secondEndPointTask); + await Task.WhenAll(firstEndPointTask, secondEndPointTask).OrTimeout(); } } @@ -370,7 +374,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests // kill the connection connection.Connection.Dispose(); - await endPointTask; + await endPointTask.OrTimeout(); } } @@ -387,13 +391,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests var firstEndPointTask = endPoint.OnConnectedAsync(firstConnection.Connection); var secondEndPointTask = endPoint.OnConnectedAsync(secondConnection.Connection); + await Task.WhenAll(firstConnection.Connected, secondConnection.Connected).OrTimeout(); + var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); await SendRequest_IgnoreReceive(firstConnection, adapter, "ClientSendMethod", secondConnection.Connection.User.Identity.Name, "test"); // check that 'secondConnection' has received the group send - var res = await ReadConnectionOutputAsync(secondConnection); + var res = await ReadConnectionOutputAsync(secondConnection).OrTimeout(); Assert.Equal("Send", res.Method); Assert.Equal(1, res.Arguments.Length); Assert.Equal("test", res.Arguments[0]); @@ -402,7 +408,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests firstConnection.Connection.Dispose(); secondConnection.Connection.Dispose(); - await Task.WhenAll(firstEndPointTask, secondEndPointTask); + await Task.WhenAll(firstEndPointTask, secondEndPointTask).OrTimeout(); } } @@ -419,13 +425,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests var firstEndPointTask = endPoint.OnConnectedAsync(firstConnection.Connection); var secondEndPointTask = endPoint.OnConnectedAsync(secondConnection.Connection); + await Task.WhenAll(firstConnection.Connected, secondConnection.Connected).OrTimeout(); + var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); await SendRequest_IgnoreReceive(firstConnection, adapter, "ConnectionSendMethod", secondConnection.Connection.ConnectionId, "test"); // check that 'secondConnection' has received the group send - var result = await ReadConnectionOutputAsync(secondConnection); + var result = await ReadConnectionOutputAsync(secondConnection).OrTimeout(); Assert.Equal("Send", result.Method); Assert.Equal(1, result.Arguments.Length); Assert.Equal("test", result.Arguments[0]); @@ -434,7 +442,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests firstConnection.Connection.Dispose(); secondConnection.Connection.Dispose(); - await Task.WhenAll(firstEndPointTask, secondEndPointTask); + await Task.WhenAll(firstEndPointTask, secondEndPointTask).OrTimeout(); } } @@ -518,6 +526,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests private class MethodHub : Hub { + public override Task OnConnectedAsync() + { + Context.Connection.Metadata.Get>("ConnectedTask").SetResult(true); + return base.OnConnectedAsync(); + } + public Task GroupRemoveMethod(string groupName) { return Groups.RemoveAsync(groupName); @@ -604,11 +618,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests public class ConnectionWrapper : IDisposable { private static int _id; - + public Connection Connection { get; } public IChannelConnection Application { get; } + public Task Connected => Connection.Metadata.Get>("ConnectedTask").Task; + public ConnectionWrapper(string format = "json") { var transportToApplication = Channel.CreateUnbounded(); @@ -620,6 +636,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests Connection = new Connection(Guid.NewGuid().ToString(), transport); Connection.Metadata["formatType"] = format; Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) })); + + Connection.Metadata["ConnectedTask"] = new TaskCompletionSource(); } public void Dispose() diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TaskExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TaskExtensions.cs new file mode 100644 index 0000000000..a4e3ad62bb --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TaskExtensions.cs @@ -0,0 +1,42 @@ +using System; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public static class TaskExtensions + { + private const int DefaultTimeout = 5000; + + public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout) + { + return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds)); + } + + public static async Task OrTimeout(this Task task, TimeSpan timeout) + { + var completed = await Task.WhenAny(task, Task.Delay(timeout)); + if (completed != task) + { + throw new TimeoutException(); + } + + await task; + } + + public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout) + { + return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds)); + } + + public static async Task OrTimeout(this Task task, TimeSpan timeout) + { + var completed = await Task.WhenAny(task, Task.Delay(timeout)); + if (completed != task) + { + throw new TimeoutException(); + } + + return await task; + } + } +}