Adding TaskQueue
This commit is contained in:
parent
abc9109cf3
commit
a93839e1b2
|
|
@ -0,0 +1,70 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets.Client.Internal
|
||||
{
|
||||
// Allows serial queuing of Task instances
|
||||
// The tasks are not called on the current synchronization context
|
||||
|
||||
public sealed class TaskQueue
|
||||
{
|
||||
private readonly object _lockObj = new object();
|
||||
private Task _lastQueuedTask;
|
||||
private volatile bool _drained;
|
||||
|
||||
public TaskQueue()
|
||||
: this(Task.CompletedTask)
|
||||
{ }
|
||||
|
||||
public TaskQueue(Task initialTask)
|
||||
{
|
||||
_lastQueuedTask = initialTask;
|
||||
}
|
||||
|
||||
public bool IsDrained
|
||||
{
|
||||
get { return _drained; }
|
||||
}
|
||||
|
||||
public Task Enqueue(Func<Task> taskFunc)
|
||||
{
|
||||
return Enqueue(s => taskFunc(), null);
|
||||
}
|
||||
|
||||
public Task Enqueue(Func<object, Task> taskFunc, object state)
|
||||
{
|
||||
lock (_lockObj)
|
||||
{
|
||||
if (_drained)
|
||||
{
|
||||
return _lastQueuedTask;
|
||||
}
|
||||
|
||||
var newTask = _lastQueuedTask.ContinueWith((t, s1) =>
|
||||
{
|
||||
if (t.IsFaulted || t.IsCanceled)
|
||||
{
|
||||
return t;
|
||||
}
|
||||
return taskFunc(s1);
|
||||
},
|
||||
state).Unwrap();
|
||||
_lastQueuedTask = newTask;
|
||||
return newTask;
|
||||
}
|
||||
}
|
||||
|
||||
public Task Drain()
|
||||
{
|
||||
lock (_lockObj)
|
||||
{
|
||||
_drained = true;
|
||||
|
||||
return _lastQueuedTask;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Sockets.Client.Internal;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.Client.Tests
|
||||
{
|
||||
public class TaskQueueTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task DrainingTaskQueueShutsQueueOff()
|
||||
{
|
||||
var queue = new TaskQueue();
|
||||
await queue.Enqueue(() => Task.CompletedTask);
|
||||
await queue.Drain();
|
||||
|
||||
// This would throw if the task was queued successfully
|
||||
await queue.Enqueue(() => Task.FromException(new Exception()));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TaskQueueDoesNotQueueNewTasksIfPreviousTaskFaulted()
|
||||
{
|
||||
var exception = new Exception();
|
||||
var queue = new TaskQueue();
|
||||
var ignore = queue.Enqueue(() => Task.FromException(exception));
|
||||
var task = queue.Enqueue(() => Task.CompletedTask);
|
||||
|
||||
var actual = await Assert.ThrowsAsync<Exception>(async () => await task);
|
||||
Assert.Same(exception, actual);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TaskQueueRunsTasksInSequence()
|
||||
{
|
||||
var queue = new TaskQueue();
|
||||
int n = 0;
|
||||
queue.Enqueue(() =>
|
||||
{
|
||||
n++;
|
||||
return Task.CompletedTask;
|
||||
});
|
||||
|
||||
Task task = queue.Enqueue(() =>
|
||||
{
|
||||
return Task.Delay(100).ContinueWith(t => n++);
|
||||
});
|
||||
|
||||
task.Wait();
|
||||
Assert.Equal(n, 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue