diff --git a/test/Common/TaskExtensions.cs b/test/Common/TaskExtensions.cs deleted file mode 100644 index 4621def836..0000000000 --- a/test/Common/TaskExtensions.cs +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Diagnostics; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.SignalR.Tests.Common -{ - public static class TaskExtensions - { - private const int DefaultTimeout = 5000; - - public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) - { - return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber); - } - - public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) - { - var cts = new CancellationTokenSource(); - var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout, cts.Token)); - if (completed != task) - { - throw new TimeoutException(GetMessage(memberName, filePath, lineNumber)); - } - cts.Cancel(); - - await task; - } - - public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) - { - return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber); - } - - public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) - { - var cts = new CancellationTokenSource(); - var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout, cts.Token)); - if (completed != task) - { - throw new TimeoutException(GetMessage(memberName, filePath, lineNumber)); - } - cts.Cancel(); - - return await task; - } - - private static string GetMessage(string memberName, string filePath, int? lineNumber) - { - if (!string.IsNullOrEmpty(memberName)) - { - return $"Operation in {memberName} timed out at {filePath}:{lineNumber}"; - } - else - { - return "Operation timed out"; - } - } - } -} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs index 65a333e142..9fd7057005 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs @@ -17,6 +17,12 @@ namespace System.Threading.Tasks public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) { + if (task.IsCompleted) + { + await task; + return; + } + var cts = new CancellationTokenSource(); var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout, cts.Token)); if (completed != task) @@ -35,6 +41,11 @@ namespace System.Threading.Tasks public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) { + if (task.IsCompleted) + { + return await task; + } + var cts = new CancellationTokenSource(); var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout, cts.Token)); if (completed != task) diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index 3fa1dd8307..089e11b8fe 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -290,7 +290,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests logger.LogInformation("Sent message", bytes.Length); logger.LogInformation("Receiving message"); - var receivedData = await receiveTcs.Task.OrTimeout(); + // No timeout here because it can take a while to receive all the bytes + var receivedData = await receiveTcs.Task; Assert.Equal(message, Encoding.UTF8.GetString(receivedData)); logger.LogInformation("Completed receive"); }