From a93839e1b21c863b0d4d0f6d3a5892d7262c00c4 Mon Sep 17 00:00:00 2001 From: moozzyk Date: Mon, 6 Mar 2017 10:20:57 -0800 Subject: [PATCH] Adding TaskQueue --- .../TaskQueue.cs | 70 +++++++++++++++++++ .../TaskQueueTests.cs | 56 +++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 src/Microsoft.AspNetCore.Sockets.Client/TaskQueue.cs create mode 100644 test/Microsoft.AspNetCore.Sockets.Client.Tests/TaskQueueTests.cs diff --git a/src/Microsoft.AspNetCore.Sockets.Client/TaskQueue.cs b/src/Microsoft.AspNetCore.Sockets.Client/TaskQueue.cs new file mode 100644 index 0000000000..7b78d390ff --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Client/TaskQueue.cs @@ -0,0 +1,70 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets.Client.Internal +{ + // Allows serial queuing of Task instances + // The tasks are not called on the current synchronization context + + public sealed class TaskQueue + { + private readonly object _lockObj = new object(); + private Task _lastQueuedTask; + private volatile bool _drained; + + public TaskQueue() + : this(Task.CompletedTask) + { } + + public TaskQueue(Task initialTask) + { + _lastQueuedTask = initialTask; + } + + public bool IsDrained + { + get { return _drained; } + } + + public Task Enqueue(Func taskFunc) + { + return Enqueue(s => taskFunc(), null); + } + + public Task Enqueue(Func taskFunc, object state) + { + lock (_lockObj) + { + if (_drained) + { + return _lastQueuedTask; + } + + var newTask = _lastQueuedTask.ContinueWith((t, s1) => + { + if (t.IsFaulted || t.IsCanceled) + { + return t; + } + return taskFunc(s1); + }, + state).Unwrap(); + _lastQueuedTask = newTask; + return newTask; + } + } + + public Task Drain() + { + lock (_lockObj) + { + _drained = true; + + return _lastQueuedTask; + } + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Sockets.Client.Tests/TaskQueueTests.cs b/test/Microsoft.AspNetCore.Sockets.Client.Tests/TaskQueueTests.cs new file mode 100644 index 0000000000..7ae8dfce67 --- /dev/null +++ b/test/Microsoft.AspNetCore.Sockets.Client.Tests/TaskQueueTests.cs @@ -0,0 +1,56 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Sockets.Client.Internal; +using Xunit; + +namespace Microsoft.AspNetCore.Client.Tests +{ + public class TaskQueueTests + { + [Fact] + public async Task DrainingTaskQueueShutsQueueOff() + { + var queue = new TaskQueue(); + await queue.Enqueue(() => Task.CompletedTask); + await queue.Drain(); + + // This would throw if the task was queued successfully + await queue.Enqueue(() => Task.FromException(new Exception())); + } + + [Fact] + public async Task TaskQueueDoesNotQueueNewTasksIfPreviousTaskFaulted() + { + var exception = new Exception(); + var queue = new TaskQueue(); + var ignore = queue.Enqueue(() => Task.FromException(exception)); + var task = queue.Enqueue(() => Task.CompletedTask); + + var actual = await Assert.ThrowsAsync(async () => await task); + Assert.Same(exception, actual); + } + + [Fact] + public void TaskQueueRunsTasksInSequence() + { + var queue = new TaskQueue(); + int n = 0; + queue.Enqueue(() => + { + n++; + return Task.CompletedTask; + }); + + Task task = queue.Enqueue(() => + { + return Task.Delay(100).ContinueWith(t => n++); + }); + + task.Wait(); + Assert.Equal(n, 2); + } + } +}