// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Diagnostics;
using System.IO.Pipelines;
using System.Runtime.CompilerServices;

namespace System.Threading.Tasks
{
    /// <summary>
    /// Helpers that bound how long a caller waits on a task, and that surface the failure of a
    /// companion task instead of waiting on the primary one indefinitely.
    /// </summary>
    public static class TaskExtensions
    {
        // Default wait, in milliseconds, for the OrTimeout overloads that take an int.
        private const int DefaultTimeout = 5000;

        /// <summary>
        /// Awaits <paramref name="task"/>, throwing <see cref="TimeoutException"/> if it has not
        /// completed within <paramref name="milliseconds"/>.
        /// </summary>
        /// <param name="task">The task to await.</param>
        /// <param name="milliseconds">Maximum wait in milliseconds (default 5000).</param>
        /// <param name="memberName">Supplied by the compiler; used only in the timeout message.</param>
        /// <param name="filePath">Supplied by the compiler; used only in the timeout message.</param>
        /// <param name="lineNumber">Supplied by the compiler; used only in the timeout message.</param>
        public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
        {
            return OrTimeout(task, TimeSpan.FromMilliseconds(milliseconds), memberName, filePath, lineNumber);
        }

        /// <summary>
        /// Awaits <paramref name="task"/>, throwing <see cref="TimeoutException"/> if it has not
        /// completed within <paramref name="timeout"/>. While a debugger is attached the timeout is
        /// ignored and the wait is unbounded.
        /// </summary>
        /// <remarks>
        /// On timeout, <paramref name="task"/> is not cancelled; it is simply no longer awaited.
        /// If <paramref name="task"/> completes in time, its outcome (including any exception)
        /// is propagated to the caller.
        /// </remarks>
        /// <param name="task">The task to await.</param>
        /// <param name="timeout">Maximum time to wait.</param>
        /// <param name="memberName">Supplied by the compiler; used only in the timeout message.</param>
        /// <param name="filePath">Supplied by the compiler; used only in the timeout message.</param>
        /// <param name="lineNumber">Supplied by the compiler; used only in the timeout message.</param>
        /// <exception cref="TimeoutException"><paramref name="task"/> did not complete in time.</exception>
        public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
        {
            // Fast path: no delay task or token source is needed for an already-finished task.
            if (task.IsCompleted)
            {
                await task;
                return;
            }

            // The token lets us release the pending delay as soon as the real task wins the race.
            using (var cts = new CancellationTokenSource())
            {
                var delay = Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout, cts.Token);
                var completed = await Task.WhenAny(task, delay);
                if (completed != task)
                {
                    throw new TimeoutException(GetMessage(memberName, filePath, lineNumber));
                }

                cts.Cancel();
            }

            // The task has finished; awaiting it propagates its result or exception.
            await task;
        }

        /// <summary>
        /// Awaits <paramref name="task"/> and returns its result, throwing <see cref="TimeoutException"/>
        /// if it has not completed within <paramref name="milliseconds"/>.
        /// </summary>
        /// <typeparam name="T">Result type of the task.</typeparam>
        /// <param name="task">The task to await.</param>
        /// <param name="milliseconds">Maximum wait in milliseconds (default 5000).</param>
        /// <param name="memberName">Supplied by the compiler; used only in the timeout message.</param>
        /// <param name="filePath">Supplied by the compiler; used only in the timeout message.</param>
        /// <param name="lineNumber">Supplied by the compiler; used only in the timeout message.</param>
        /// <returns>The result of <paramref name="task"/>.</returns>
        public static Task<T> OrTimeout<T>(this Task<T> task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
        {
            return OrTimeout(task, TimeSpan.FromMilliseconds(milliseconds), memberName, filePath, lineNumber);
        }

        /// <summary>
        /// Awaits <paramref name="task"/> and returns its result, throwing <see cref="TimeoutException"/>
        /// if it has not completed within <paramref name="timeout"/>. While a debugger is attached the
        /// timeout is ignored and the wait is unbounded.
        /// </summary>
        /// <typeparam name="T">Result type of the task.</typeparam>
        /// <param name="task">The task to await.</param>
        /// <param name="timeout">Maximum time to wait.</param>
        /// <param name="memberName">Supplied by the compiler; used only in the timeout message.</param>
        /// <param name="filePath">Supplied by the compiler; used only in the timeout message.</param>
        /// <param name="lineNumber">Supplied by the compiler; used only in the timeout message.</param>
        /// <returns>The result of <paramref name="task"/>.</returns>
        /// <exception cref="TimeoutException"><paramref name="task"/> did not complete in time.</exception>
        public static async Task<T> OrTimeout<T>(this Task<T> task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
        {
            // Reuse the non-generic race; if it returns normally, 'task' has completed successfully.
            await OrTimeout((Task)task, timeout, memberName, filePath, lineNumber);
            return await task;
        }

        /// <summary>
        /// Awaits <paramref name="task"/>, but rethrows the failure of <paramref name="otherTask"/>
        /// if that task faults first.
        /// </summary>
        /// <remarks>
        /// If <paramref name="otherTask"/> completes first without faulting (successfully or
        /// cancelled), this method keeps waiting on <paramref name="task"/>.
        /// </remarks>
        /// <param name="task">The task whose completion is awaited.</param>
        /// <param name="otherTask">A companion task whose fault should abort the wait.</param>
        public static async Task OrThrowIfOtherFails(this Task task, Task otherTask)
        {
            var completed = await Task.WhenAny(task, otherTask);
            if (completed == otherTask && otherTask.IsFaulted)
            {
                // Awaiting a faulted task always rethrows its exception, so control never continues here.
                await otherTask;
                return;
            }

            // Either 'task' finished, or 'otherTask' finished without faulting; it is not our job to
            // check the latter, so just await the task we were asked to await.
            await task;
        }

        /// <summary>
        /// Awaits <paramref name="task"/> and returns its result, but rethrows the failure of
        /// <paramref name="otherTask"/> if that task faults first.
        /// </summary>
        /// <typeparam name="T">Result type of the task.</typeparam>
        /// <param name="task">The task whose result is awaited.</param>
        /// <param name="otherTask">A companion task whose fault should abort the wait.</param>
        /// <returns>The result of <paramref name="task"/>.</returns>
        public static async Task<T> OrThrowIfOtherFails<T>(this Task<T> task, Task otherTask)
        {
            await OrThrowIfOtherFails((Task)task, otherTask);

            // If we get here, 'task' is finished and succeeded.
            return await task;
        }

        // Builds the TimeoutException message, naming the call site when the compiler supplied one.
        private static string GetMessage(string memberName, string filePath, int? lineNumber)
        {
            return string.IsNullOrEmpty(memberName)
                ? "Operation timed out"
                : $"Operation in {memberName} timed out at {filePath}:{lineNumber}";
        }
    }
}