diff --git a/src/Hosting/Hosting/test/Internal/HostingEventSourceTests.cs b/src/Hosting/Hosting/test/Internal/HostingEventSourceTests.cs index bcf4188a39..47a661eff7 100644 --- a/src/Hosting/Hosting/test/Internal/HostingEventSourceTests.cs +++ b/src/Hosting/Hosting/test/Internal/HostingEventSourceTests.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.Diagnostics.Tracing; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Internal; @@ -185,7 +184,7 @@ namespace Microsoft.AspNetCore.Hosting public async Task VerifyCountersFireWithCorrectValues() { // Arrange - var eventListener = new CounterListener(new[] { + var eventListener = new TestCounterListener(new[] { "requests-per-second", "total-requests", "current-requests", @@ -207,6 +206,7 @@ namespace Microsoft.AspNetCore.Hosting { "EventCounterIntervalSec", "1" } }); + // Act & Assert hostingEventSource.RequestStart("GET", "/"); Assert.Equal(1, await totalRequestValues.FirstOrDefault(v => v == 1)); @@ -241,36 +241,5 @@ namespace Microsoft.AspNetCore.Hosting { return new HostingEventSource(Guid.NewGuid().ToString()); } - - private class CounterListener : EventListener - { - private readonly Dictionary> _counters = new Dictionary>(); - - public CounterListener(string[] counterNames) - { - foreach (var item in counterNames) - { - _counters[item] = Channel.CreateUnbounded(); - } - } - - public IAsyncEnumerable GetCounterValues(string counterName, CancellationToken cancellationToken = default) - { - return _counters[counterName].Reader.ReadAllAsync(cancellationToken); - } - - protected override void OnEventWritten(EventWrittenEventArgs eventData) - { - if (eventData.EventName == "EventCounters") - { - var payload = (IDictionary)eventData.Payload[0]; - var counter = (string)payload["Name"]; - payload.TryGetValue("Increment", out var increment); - payload.TryGetValue("Mean", out var mean); - var writer = _counters[counter].Writer; - writer.TryWrite((double)(increment ?? mean)); - } - } - } } } diff --git a/src/Hosting/Hosting/test/Microsoft.AspNetCore.Hosting.Tests.csproj b/src/Hosting/Hosting/test/Microsoft.AspNetCore.Hosting.Tests.csproj index 0254330bbf..4bc7d5cf38 100644 --- a/src/Hosting/Hosting/test/Microsoft.AspNetCore.Hosting.Tests.csproj +++ b/src/Hosting/Hosting/test/Microsoft.AspNetCore.Hosting.Tests.csproj @@ -1,4 +1,4 @@ - + netcoreapp3.0 @@ -7,6 +7,7 @@ + diff --git a/src/Middleware/RequestThrottling/ref/Microsoft.AspNetCore.RequestThrottling.csproj b/src/Middleware/RequestThrottling/ref/Microsoft.AspNetCore.RequestThrottling.csproj index 8e8c1dfce6..ba0a06082e 100644 --- a/src/Middleware/RequestThrottling/ref/Microsoft.AspNetCore.RequestThrottling.csproj +++ b/src/Middleware/RequestThrottling/ref/Microsoft.AspNetCore.RequestThrottling.csproj @@ -8,5 +8,6 @@ + diff --git a/src/Middleware/RequestThrottling/ref/Microsoft.AspNetCore.RequestThrottling.netcoreapp3.0.cs b/src/Middleware/RequestThrottling/ref/Microsoft.AspNetCore.RequestThrottling.netcoreapp3.0.cs index 7fc7d24be7..49f5ed18b7 100644 --- a/src/Middleware/RequestThrottling/ref/Microsoft.AspNetCore.RequestThrottling.netcoreapp3.0.cs +++ b/src/Middleware/RequestThrottling/ref/Microsoft.AspNetCore.RequestThrottling.netcoreapp3.0.cs @@ -10,15 +10,9 @@ namespace Microsoft.AspNetCore.Builder } namespace Microsoft.AspNetCore.RequestThrottling { - public partial interface IQueuePolicy - { - void OnExit(); - System.Threading.Tasks.Task TryEnterAsync(); - } public partial class RequestThrottlingMiddleware { - public RequestThrottlingMiddleware(Microsoft.AspNetCore.Http.RequestDelegate next, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory, Microsoft.AspNetCore.RequestThrottling.IQueuePolicy queue, Microsoft.Extensions.Options.IOptions options) { } - public int QueuedRequestCount { get { throw null; } } + public RequestThrottlingMiddleware(Microsoft.AspNetCore.Http.RequestDelegate next, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory, Microsoft.AspNetCore.RequestThrottling.QueuePolicies.IQueuePolicy queue, Microsoft.Extensions.Options.IOptions options) { } [System.Diagnostics.DebuggerStepThroughAttribute] public System.Threading.Tasks.Task Invoke(Microsoft.AspNetCore.Http.HttpContext context) { throw null; } } @@ -30,6 +24,11 @@ namespace Microsoft.AspNetCore.RequestThrottling } namespace Microsoft.AspNetCore.RequestThrottling.QueuePolicies { + public partial interface IQueuePolicy + { + void OnExit(); + System.Threading.Tasks.Task TryEnterAsync(); + } public partial class QueuePolicyOptions { public QueuePolicyOptions() { } diff --git a/src/Middleware/RequestThrottling/sample/Startup.cs b/src/Middleware/RequestThrottling/sample/Startup.cs index d12dcfa943..846a7cc186 100644 --- a/src/Middleware/RequestThrottling/sample/Startup.cs +++ b/src/Middleware/RequestThrottling/sample/Startup.cs @@ -22,9 +22,9 @@ namespace RequestThrottlingSample // For more information on how to configure your application, visit https://go.microsoft.com/fwlink/?LinkID=398940 public void ConfigureServices(IServiceCollection services) { - services.AddTailDropQueue((options) => + services.AddStackQueue((options) => { - options.MaxConcurrentRequests = Math.Max(1, _config.GetValue("maxCores")); + options.MaxConcurrentRequests = Math.Max(1, _config.GetValue("maxConcurrent")); options.RequestQueueLimit = Math.Max(1, _config.GetValue("maxQueue")); }); diff --git a/src/Middleware/RequestThrottling/src/Microsoft.AspNetCore.RequestThrottling.csproj b/src/Middleware/RequestThrottling/src/Microsoft.AspNetCore.RequestThrottling.csproj index 0fc5168366..ff9d48ffaa 100644 --- a/src/Middleware/RequestThrottling/src/Microsoft.AspNetCore.RequestThrottling.csproj +++ b/src/Middleware/RequestThrottling/src/Microsoft.AspNetCore.RequestThrottling.csproj @@ -1,4 +1,4 @@ - + ASP.NET Core middleware for queuing incoming HTTP requests, to avoid threadpool starvation. @@ -12,6 +12,7 @@ + diff --git a/src/Middleware/RequestThrottling/src/QueuePolicies/IQueuePolicy.cs b/src/Middleware/RequestThrottling/src/QueuePolicies/IQueuePolicy.cs index ebaa753ffb..44a2b3b698 100644 --- a/src/Middleware/RequestThrottling/src/QueuePolicies/IQueuePolicy.cs +++ b/src/Middleware/RequestThrottling/src/QueuePolicies/IQueuePolicy.cs @@ -4,7 +4,7 @@ using System; using System.Threading.Tasks; -namespace Microsoft.AspNetCore.RequestThrottling +namespace Microsoft.AspNetCore.RequestThrottling.QueuePolicies { /// /// Queueing policies, meant to be used with the . diff --git a/src/Middleware/RequestThrottling/src/RequestThrottlingEventSource.cs b/src/Middleware/RequestThrottling/src/RequestThrottlingEventSource.cs new file mode 100644 index 0000000000..c4a997a417 --- /dev/null +++ b/src/Middleware/RequestThrottling/src/RequestThrottlingEventSource.cs @@ -0,0 +1,109 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information + +using System; +using System.Collections.Generic; +using System.Diagnostics.Tracing; +using System.Text; +using System.Threading; +using Microsoft.Extensions.Internal; + +namespace Microsoft.AspNetCore.RequestThrottling +{ + internal sealed class RequestThrottlingEventSource : EventSource + { + public static readonly RequestThrottlingEventSource Log = new RequestThrottlingEventSource(); + private static QueueFrame CachedNonTimerResult = new QueueFrame(null, Log); + + private PollingCounter _rejectedRequestsCounter; + private PollingCounter _queueLengthCounter; + private EventCounter _queueDuration; + + private long _rejectedRequests; + private int _queueLength; + + internal RequestThrottlingEventSource() + : base("Microsoft.AspNetCore.RequestThrottling") + { + } + + // Used for testing + internal RequestThrottlingEventSource(string eventSourceName) + : base(eventSourceName) + { + } + + [Event(1, Level = EventLevel.Warning)] + public void RequestRejected() + { + Interlocked.Increment(ref _rejectedRequests); + WriteEvent(1); + } + + [NonEvent] + public void QueueSkipped() + { + if (IsEnabled()) + { + _queueDuration.WriteMetric(0); + } + } + + [NonEvent] + public QueueFrame QueueTimer() + { + Interlocked.Increment(ref _queueLength); + + if (IsEnabled()) + { + return new QueueFrame(ValueStopwatch.StartNew(), this); + } + + return CachedNonTimerResult; + } + + internal struct QueueFrame : IDisposable + { + private ValueStopwatch? _timer; + private RequestThrottlingEventSource _parent; + + public QueueFrame(ValueStopwatch? timer, RequestThrottlingEventSource parent) + { + _timer = timer; + _parent = parent; + } + + public void Dispose() + { + Interlocked.Decrement(ref _parent._queueLength); + + if (_parent.IsEnabled() && _timer != null) + { + var duration = _timer.Value.GetElapsedTime().TotalMilliseconds; + _parent._queueDuration.WriteMetric(duration); + } + } + } + + protected override void OnEventCommand(EventCommandEventArgs command) + { + if (command.Command == EventCommand.Enable) + { + _rejectedRequestsCounter ??= new PollingCounter("requests-rejected", this, () => _rejectedRequests) + { + DisplayName = "Rejected Requests", + }; + + _queueLengthCounter ??= new PollingCounter("queue-length", this, () => _queueLength) + { + DisplayName = "Queue Length", + }; + + _queueDuration ??= new EventCounter("queue-duration", this) + { + DisplayName = "Average Time in Queue", + }; + } + } + } +} diff --git a/src/Middleware/RequestThrottling/src/RequestThrottlingMiddleware.cs b/src/Middleware/RequestThrottling/src/RequestThrottlingMiddleware.cs index 14e55a33e6..a30ba9ba06 100644 --- a/src/Middleware/RequestThrottling/src/RequestThrottlingMiddleware.cs +++ b/src/Middleware/RequestThrottling/src/RequestThrottlingMiddleware.cs @@ -2,9 +2,9 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.RequestThrottling.QueuePolicies; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -20,8 +20,6 @@ namespace Microsoft.AspNetCore.RequestThrottling private readonly RequestDelegate _onRejected; private readonly ILogger _logger; - private int _queuedRequests; - /// /// Creates a new . /// @@ -49,19 +47,21 @@ namespace Microsoft.AspNetCore.RequestThrottling /// A that completes when the request leaves. public async Task Invoke(HttpContext context) { - Interlocked.Increment(ref _queuedRequests); + var waitInQueueTask = _queuePolicy.TryEnterAsync(); - var success = false; - try + if (waitInQueueTask.IsCompleted) { - success = await _queuePolicy.TryEnterAsync(); + RequestThrottlingEventSource.Log.QueueSkipped(); } - finally + else { - Interlocked.Decrement(ref _queuedRequests); + using (RequestThrottlingEventSource.Log.QueueTimer()) + { + await waitInQueueTask; + } } - if (success) + if (waitInQueueTask.Result) { try { @@ -74,20 +74,13 @@ namespace Microsoft.AspNetCore.RequestThrottling } else { + RequestThrottlingEventSource.Log.RequestRejected(); RequestThrottlingLog.RequestRejectedQueueFull(_logger); context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable; await _onRejected(context); } } - /// - /// The total number of requests waiting within the middleware - /// - public int QueuedRequestCount - { - get => _queuedRequests; - } - private static class RequestThrottlingLog { private static readonly Action _requestEnqueued = diff --git a/src/Middleware/RequestThrottling/test/Microsoft.AspNetCore.RequestThrottling.Tests.csproj b/src/Middleware/RequestThrottling/test/Microsoft.AspNetCore.RequestThrottling.Tests.csproj index 78b1c88692..759af99f21 100644 --- a/src/Middleware/RequestThrottling/test/Microsoft.AspNetCore.RequestThrottling.Tests.csproj +++ b/src/Middleware/RequestThrottling/test/Microsoft.AspNetCore.RequestThrottling.Tests.csproj @@ -5,7 +5,8 @@ - + + diff --git a/src/Middleware/RequestThrottling/test/MiddlewareTests.cs b/src/Middleware/RequestThrottling/test/MiddlewareTests.cs index 67e2f017e4..1e4dfa25b7 100644 --- a/src/Middleware/RequestThrottling/test/MiddlewareTests.cs +++ b/src/Middleware/RequestThrottling/test/MiddlewareTests.cs @@ -2,8 +2,11 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information using System; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.RequestThrottling.QueuePolicies; +using Microsoft.VisualStudio.TestPlatform.ObjectModel; using Xunit; namespace Microsoft.AspNetCore.RequestThrottling.Tests @@ -16,8 +19,8 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests var flag = false; var middleware = TestUtils.CreateTestMiddleware( - queue: TestStrategy.AlwaysPass, - next: (context) => { + queue: TestQueue.AlwaysTrue, + next: httpContext => { flag = true; return Task.CompletedTask; }); @@ -27,23 +30,12 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests } [Fact] - public async Task RequestRejectsIfQueueReturnsFalse() - { - var middleware = TestUtils.CreateTestMiddleware( - queue: TestStrategy.AlwaysReject); - - var context = new DefaultHttpContext(); - await middleware.Invoke(context); - Assert.Equal(StatusCodes.Status503ServiceUnavailable, context.Response.StatusCode); - } - - [Fact] - public async void FullQueueInvokesOnRejected() + public async void RequestRejectsIfQueueReturnsFalse() { bool onRejectedInvoked = false; var middleware = TestUtils.CreateTestMiddleware( - queue: TestStrategy.AlwaysReject, + queue: TestQueue.AlwaysFalse, onRejected: httpContext => { onRejectedInvoked = true; @@ -57,14 +49,14 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests } [Fact] - public async void RequestsBlockedIfQueueFull() + public async void RequestsDoesNotEnterIfQueueFull() { var middleware = TestUtils.CreateTestMiddleware( - queue: TestStrategy.AlwaysReject, + queue: TestQueue.AlwaysFalse, next: httpContext => { // throttle should bounce the request; it should never get here - throw new NotImplementedException(); + throw new DivideByZeroException(); }); await middleware.Invoke(new DefaultHttpContext()).OrTimeout(); @@ -73,69 +65,118 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests [Fact] public void IncomingRequestsFillUpQueue() { - var middleware = TestUtils.CreateTestMiddleware( - queue: TestStrategy.AlwaysBlock); + var testQueue = TestQueue.AlwaysBlock; + var middleware = TestUtils.CreateTestMiddleware(testQueue); - Assert.Equal(0, middleware.QueuedRequestCount); + Assert.Equal(0, testQueue.QueuedRequests); _ = middleware.Invoke(new DefaultHttpContext()); - Assert.Equal(1, middleware.QueuedRequestCount); + Assert.Equal(1, testQueue.QueuedRequests); _ = middleware.Invoke(new DefaultHttpContext()); - Assert.Equal(2, middleware.QueuedRequestCount); + Assert.Equal(2, testQueue.QueuedRequests); } [Fact] - public async Task CleanupHappensEvenIfNextErrors() + public void EventCountersTrackQueuedRequests() + { + var blocker = new TaskCompletionSource(); + + var testQueue = new TestQueue( + onTryEnter: async (_) => + { + return await blocker.Task; + }); + var middleware = TestUtils.CreateTestMiddleware(testQueue); + + Assert.Equal(0, testQueue.QueuedRequests); + + var task1 = middleware.Invoke(new DefaultHttpContext()); + Assert.False(task1.IsCompleted); + Assert.Equal(1, testQueue.QueuedRequests); + + blocker.SetResult(true); + + Assert.Equal(0, testQueue.QueuedRequests); + } + + [Fact] + public async Task QueueOnExitCalledEvenIfNextErrors() { var flag = false; + var testQueue = new TestQueue( + onTryEnter: (_) => true, + onExit: () => { flag = true; }); + var middleware = TestUtils.CreateTestMiddleware( - queue: new TestStrategy( - invoke: (() => true), - onExit: () => { flag = true; }), + queue: testQueue, next: httpContext => { throw new DivideByZeroException(); }); - Assert.Equal(0, middleware.QueuedRequestCount); + Assert.Equal(0, testQueue.QueuedRequests); await Assert.ThrowsAsync(() => middleware.Invoke(new DefaultHttpContext())).OrTimeout(); - Assert.Equal(0, middleware.QueuedRequestCount); + Assert.Equal(0, testQueue.QueuedRequests); Assert.True(flag); } [Fact] public async void ExceptionThrownDuringOnRejected() { - TaskCompletionSource tsc = new TaskCompletionSource(); + TaskCompletionSource tcs = new TaskCompletionSource(); + + var concurrent = 0; + var testQueue = new TestQueue( + onTryEnter: (testQueue) => + { + if (concurrent > 0) + { + return false; + } + else + { + concurrent++; + return true; + } + }, + onExit: () => { concurrent--; }); var middleware = TestUtils.CreateTestMiddleware( + queue: testQueue, onRejected: httpContext => { throw new DivideByZeroException(); }, next: httpContext => { - return tsc.Task; + return tcs.Task; }); + // the first request enters the server, and is blocked by the tcs var firstRequest = middleware.Invoke(new DefaultHttpContext()); + Assert.Equal(1, concurrent); + Assert.Equal(0, testQueue.QueuedRequests); + // the second request is rejected with a 503 error. During the rejection, an error occurs var context = new DefaultHttpContext(); await Assert.ThrowsAsync(() => middleware.Invoke(context)).OrTimeout(); Assert.Equal(StatusCodes.Status503ServiceUnavailable, context.Response.StatusCode); + Assert.Equal(1, concurrent); + Assert.Equal(0, testQueue.QueuedRequests); - tsc.SetResult(true); - + // the first request is unblocked, and the queue continues functioning as expected + tcs.SetResult(true); Assert.True(firstRequest.IsCompletedSuccessfully); + Assert.Equal(0, concurrent); + Assert.Equal(0, testQueue.QueuedRequests); var thirdRequest = middleware.Invoke(new DefaultHttpContext()); - Assert.True(thirdRequest.IsCompletedSuccessfully); - - Assert.Equal(0, middleware.QueuedRequestCount); + Assert.Equal(0, concurrent); + Assert.Equal(0, testQueue.QueuedRequests); } } } diff --git a/src/Middleware/RequestThrottling/test/PolicyTests/TailDropTests.cs b/src/Middleware/RequestThrottling/test/PolicyTests/TailDropTests.cs index d352325eb8..82ea4e2ceb 100644 --- a/src/Middleware/RequestThrottling/test/PolicyTests/TailDropTests.cs +++ b/src/Middleware/RequestThrottling/test/PolicyTests/TailDropTests.cs @@ -6,7 +6,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Internal; using Xunit; -namespace Microsoft.AspNetCore.RequestThrottling.Tests +namespace Microsoft.AspNetCore.RequestThrottling.Tests.PolicyTests { public class TailDropTests { diff --git a/src/Middleware/RequestThrottling/test/RequestThrottlingEventSourceTests.cs b/src/Middleware/RequestThrottling/test/RequestThrottlingEventSourceTests.cs new file mode 100644 index 0000000000..7725f6d7c4 --- /dev/null +++ b/src/Middleware/RequestThrottling/test/RequestThrottlingEventSourceTests.cs @@ -0,0 +1,148 @@ +// Copyright(c) .NET Foundation.All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics.Tracing; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Internal; +using Xunit; + +namespace Microsoft.AspNetCore.RequestThrottling.Tests +{ + public class RequestThrottlingEventSourceTests + { + [Fact] + public void MatchesNameAndGuid() + { + var eventSource = new RequestThrottlingEventSource(); + + Assert.Equal("Microsoft.AspNetCore.RequestThrottling", eventSource.Name); + Assert.Equal(Guid.Parse("436f1cb1-8acc-56c0-86ec-e0832bd696ed"), eventSource.Guid); + } + + [Fact] + public void RecordsRequestsRejected() + { + // Arrange + var expectedId = 1; + var eventListener = new TestEventListener(expectedId); + var eventSource = GetRequestThrottlingEventSource(); + eventListener.EnableEvents(eventSource, EventLevel.Informational); + + // Act + eventSource.RequestRejected(); + + // Assert + var eventData = eventListener.EventData; + Assert.NotNull(eventData); + Assert.Equal(expectedId, eventData.EventId); + Assert.Equal(EventLevel.Warning, eventData.Level); + Assert.Same(eventSource, eventData.EventSource); + Assert.Null(eventData.Message); + Assert.Empty(eventData.Payload); + } + + [Fact] + public async Task TracksQueueLength() + { + // Arrange + using var eventListener = new TestCounterListener(new[] { + "queue-length", + "queue-duration", + "requests-rejected", + }); + + using var eventSource = GetRequestThrottlingEventSource(); + + using var timeoutTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + + var lengthValues = eventListener.GetCounterValues("queue-length", timeoutTokenSource.Token).GetAsyncEnumerator(); + + eventListener.EnableEvents(eventSource, EventLevel.Informational, EventKeywords.None, + new Dictionary + { + {"EventCounterIntervalSec", "1" } + }); + + // Act + eventSource.RequestRejected(); + + Assert.True(await UntilValueMatches(lengthValues, 0)); + using (eventSource.QueueTimer()) + { + Assert.True(await UntilValueMatches(lengthValues, 1)); + + using (eventSource.QueueTimer()) + { + Assert.True(await UntilValueMatches(lengthValues, 2)); + } + + Assert.True(await UntilValueMatches(lengthValues, 1)); + } + + Assert.True(await UntilValueMatches(lengthValues, 0)); + } + + [Fact] + public async Task TracksDurationSpentInQueue() + { + // Arrange + using var eventListener = new TestCounterListener(new[] { + "queue-length", + "queue-duration", + "requests-rejected", + }); + + using var eventSource = GetRequestThrottlingEventSource(); + + using var timeoutTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + + var durationValues = eventListener.GetCounterValues("queue-duration", timeoutTokenSource.Token).GetAsyncEnumerator(); + + eventListener.EnableEvents(eventSource, EventLevel.Informational, EventKeywords.None, + new Dictionary + { + {"EventCounterIntervalSec", "1" } + }); + + // Act + Assert.True(await UntilValueMatches(durationValues, 0)); + + using (eventSource.QueueTimer()) + { + Assert.True(await UntilValueMatches(durationValues, 0)); + } + + // check that something (anything!) has been written + while (await durationValues.MoveNextAsync()) + { + if (durationValues.Current > 0) + { + return; + } + } + + throw new TimeoutException(); + } + + private async Task UntilValueMatches(IAsyncEnumerator enumerator, int value) + { + while (await enumerator.MoveNextAsync()) + { + if (enumerator.Current == value) + { + return true; + } + } + + return false; + } + + private static RequestThrottlingEventSource GetRequestThrottlingEventSource() + { + return new RequestThrottlingEventSource(Guid.NewGuid().ToString()); + } + } +} diff --git a/src/Middleware/RequestThrottling/test/TestUtils.cs b/src/Middleware/RequestThrottling/test/TestUtils.cs index 4617914294..008b3ab4a1 100644 --- a/src/Middleware/RequestThrottling/test/TestUtils.cs +++ b/src/Middleware/RequestThrottling/test/TestUtils.cs @@ -69,30 +69,33 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests } } - public class TestStrategy : IQueuePolicy + internal class TestQueue : IQueuePolicy { - private Func> _invoke { get; } + private Func> _onTryEnter { get; } private Action _onExit { get; } - public TestStrategy(Func> invoke, Action onExit = null) + private int _queuedRequests; + public int QueuedRequests { get => _queuedRequests; } + + public TestQueue(Func> onTryEnter, Action onExit = null) { - _invoke = invoke; + _onTryEnter = onTryEnter; _onExit = onExit ?? (() => { }); } - public TestStrategy(Func invoke, Action onExit = null) - : this(async () => - { - await Task.CompletedTask; - return invoke(); - }, - onExit) - { } - + public TestQueue(Func onTryEnter, Action onExit = null) : + this(async (state) => + { + await Task.CompletedTask; + return onTryEnter(state); + }, onExit) { } + public async Task TryEnterAsync() { - await Task.CompletedTask; - return await _invoke(); + Interlocked.Increment(ref _queuedRequests); + var result = await _onTryEnter(this); + Interlocked.Decrement(ref _queuedRequests); + return result; } public void OnExit() @@ -100,14 +103,14 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests _onExit(); } - public static TestStrategy AlwaysReject = - new TestStrategy(() => false); + public static TestQueue AlwaysFalse = + new TestQueue((_) => false); - public static TestStrategy AlwaysPass = - new TestStrategy(() => true); + public static TestQueue AlwaysTrue = + new TestQueue((_) => true); - public static TestStrategy AlwaysBlock = - new TestStrategy(async () => + public static TestQueue AlwaysBlock = + new TestQueue(async (_) => { await new SemaphoreSlim(0).WaitAsync(); return false; diff --git a/src/Shared/EventSource.Testing/TestCounterListener.cs b/src/Shared/EventSource.Testing/TestCounterListener.cs new file mode 100644 index 0000000000..f31551c83e --- /dev/null +++ b/src/Shared/EventSource.Testing/TestCounterListener.cs @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Diagnostics.Tracing; +using System.Threading; +using System.Threading.Channels; + +namespace Microsoft.AspNetCore.Internal +{ + internal class TestCounterListener : EventListener + { + private readonly Dictionary> _counters = new Dictionary>(); + + /// + /// Creates a new TestCounterListener. + /// + /// The names of ALL counters for the event source. You must name each counter, even if you do not intend to use it. + public TestCounterListener(string[] counterNames) + { + foreach (var item in counterNames) + { + _counters[item] = Channel.CreateUnbounded(); + } + } + + public IAsyncEnumerable GetCounterValues(string counterName, CancellationToken cancellationToken = default) + { + return _counters[counterName].Reader.ReadAllAsync(cancellationToken); + } + + protected override void OnEventWritten(EventWrittenEventArgs eventData) + { + if (eventData.EventName == "EventCounters") + { + var payload = (IDictionary)eventData.Payload[0]; + var counter = (string)payload["Name"]; + payload.TryGetValue("Increment", out var increment); + payload.TryGetValue("Mean", out var mean); + var writer = _counters[counter].Writer; + writer.TryWrite((double)(increment ?? mean)); + } + } + } +}