Dylan/request queue di (#11167)

Decoupled Middleware from QueuePolicy
This commit is contained in:
Dylan Dmitri Gray 2019-06-13 16:36:21 -07:00 committed by GitHub
parent 82d2b4f4d0
commit c24c4cac01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 340 additions and 305 deletions

View File

@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
@ -6,6 +6,10 @@
<!--<StartupObject>Microsoft.AspNetCore.RequestThrottling.Microbenchmarks.Test</StartupObject>-->
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\test\Microsoft.AspNetCore.RequestThrottling.Tests.csproj" />
</ItemGroup>
<ItemGroup>
<Reference Include="BenchmarkDotNet" />
<Reference Include="Microsoft.AspNetCore.BenchmarkRunner.Sources" />

View File

@ -1,14 +1,11 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Tasks;
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Running;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Microsoft.AspNetCore.RequestThrottling.Tests;
namespace Microsoft.AspNetCore.RequestThrottling.Microbenchmarks
{
@ -24,17 +21,10 @@ namespace Microsoft.AspNetCore.RequestThrottling.Microbenchmarks
{
_restOfServer = YieldsThreadInternally ? (RequestDelegate)YieldsThread : (RequestDelegate)CompletesImmediately;
var options = new RequestThrottlingOptions
{
MaxConcurrentRequests = 8,
RequestQueueLimit = _numRequests
};
_middleware = new RequestThrottlingMiddleware(
next: _restOfServer,
loggerFactory: NullLoggerFactory.Instance,
options: Options.Create(options)
);
_middleware = TestUtils.CreateTestMiddleware_TailDrop(
maxConcurrentRequests: 1,
requestQueueLimit: 0,
next: _restOfServer);
}
[Params(false, true)]

View File

@ -1,12 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading;
using System.Threading.Tasks;
using BenchmarkDotNet.Attributes;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Microsoft.AspNetCore.RequestThrottling.Tests;
namespace Microsoft.AspNetCore.RequestThrottling.Microbenchmarks
{
@ -24,16 +23,10 @@ namespace Microsoft.AspNetCore.RequestThrottling.Microbenchmarks
[GlobalSetup]
public void GlobalSetup()
{
var options = new RequestThrottlingOptions
{
MaxConcurrentRequests = MaxConcurrentRequests,
RequestQueueLimit = _numRequests
};
_middleware = new RequestThrottlingMiddleware(
next: (RequestDelegate)_incrementAndCheck,
loggerFactory: NullLoggerFactory.Instance,
options: Options.Create(options)
_middleware = TestUtils.CreateTestMiddleware_TailDrop(
maxConcurrentRequests: MaxConcurrentRequests,
requestQueueLimit: _numRequests,
next: IncrementAndCheck
);
}
@ -44,7 +37,7 @@ namespace Microsoft.AspNetCore.RequestThrottling.Microbenchmarks
_mres.Reset();
}
private async Task _incrementAndCheck(HttpContext context)
private async Task IncrementAndCheck(HttpContext context)
{
if (Interlocked.Increment(ref _requestCount) == _numRequests)
{
@ -59,7 +52,7 @@ namespace Microsoft.AspNetCore.RequestThrottling.Microbenchmarks
{
for (int i = 0; i < _numRequests; i++)
{
_ = _incrementAndCheck(null);
_ = IncrementAndCheck(null);
}
_mres.Wait();

View File

@ -10,18 +10,37 @@ namespace Microsoft.AspNetCore.Builder
}
namespace Microsoft.AspNetCore.RequestThrottling
{
public partial interface IQueuePolicy
{
void OnExit();
System.Threading.Tasks.Task<bool> TryEnterAsync();
}
public partial class RequestThrottlingMiddleware
{
public RequestThrottlingMiddleware(Microsoft.AspNetCore.Http.RequestDelegate next, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.RequestThrottling.RequestThrottlingOptions> options) { }
public int ActiveRequestCount { get { throw null; } }
public RequestThrottlingMiddleware(Microsoft.AspNetCore.Http.RequestDelegate next, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory, Microsoft.AspNetCore.RequestThrottling.IQueuePolicy queue, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.RequestThrottling.RequestThrottlingOptions> options) { }
public int QueuedRequestCount { get { throw null; } }
[System.Diagnostics.DebuggerStepThroughAttribute]
public System.Threading.Tasks.Task Invoke(Microsoft.AspNetCore.Http.HttpContext context) { throw null; }
}
public partial class RequestThrottlingOptions
{
public RequestThrottlingOptions() { }
public int? MaxConcurrentRequests { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public Microsoft.AspNetCore.Http.RequestDelegate OnRejected { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
}
}
namespace Microsoft.AspNetCore.RequestThrottling.Policies
{
public partial class TailDropOptions
{
public TailDropOptions() { }
public int MaxConcurrentRequests { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public int RequestQueueLimit { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
}
}
namespace Microsoft.Extensions.DependencyInjection
{
public static partial class QueuePolicyServiceCollectionExtensions
{
public static Microsoft.Extensions.DependencyInjection.IServiceCollection AddTailDropQueue(this Microsoft.Extensions.DependencyInjection.IServiceCollection services, System.Action<Microsoft.AspNetCore.RequestThrottling.Policies.TailDropOptions> configure) { throw null; }
}
}

View File

@ -6,7 +6,6 @@ using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.RequestThrottling;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
@ -19,11 +18,13 @@ namespace RequestThrottlingSample
// For more information on how to configure your application, visit https://go.microsoft.com/fwlink/?LinkID=398940
public void ConfigureServices(IServiceCollection services)
{
services.Configure<RequestThrottlingOptions>(options =>
services.AddTailDropQueue((options) =>
{
options.MaxConcurrentRequests = 2;
options.MaxConcurrentRequests = 4;
options.RequestQueueLimit = 0;
});
services.AddLogging();
}
public void Configure(IApplicationBuilder app, ILoggerFactory loggerFactory)

View File

@ -0,0 +1,27 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information
using System;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.RequestThrottling
{
/// <summary>
/// Queueing policies, meant to be used with the <see cref="RequestThrottlingMiddleware"></see>.
/// </summary>
public interface IQueuePolicy
{
/// <summary>
/// Called for every incoming request.
/// When it returns 'true' the request procedes to the server.
/// When it returns 'false' the request is rejected immediately.
/// </summary>
Task<bool> TryEnterAsync();
/// <summary>
/// Called after successful requests have been returned from the server.
/// Does NOT get called for rejected requests.
/// </summary>
void OnExit();
}
}

View File

@ -1,17 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information
using System;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.RequestThrottling.Internal
{
interface IRequestQueue : IDisposable
{
int TotalRequests { get; }
Task<bool> TryEnterQueueAsync();
void Release();
}
}

View File

@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<Description>ASP.NET Core middleware for queuing incoming HTTP requests, to avoid threadpool starvation.</Description>

View File

@ -0,0 +1,34 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Text;
using System.Xml.Schema;
using Microsoft.AspNetCore.RequestThrottling;
using Microsoft.AspNetCore.RequestThrottling.Policies;
using Microsoft.Extensions.DependencyInjection.Extensions;
namespace Microsoft.Extensions.DependencyInjection
{
/// <summary>
/// Contains methods for adding Q
/// </summary>
public static class QueuePolicyServiceCollectionExtensions
{
/// <summary>
/// Tells <see cref="RequestThrottlingMiddleware"/> to use a TailDrop queue as its queueing strategy.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> to add services to.</param>
/// <param name="configure">Set the options used by the queue.
/// Mandatory, since <see cref="TailDropOptions.MaxConcurrentRequests"></see> must be provided.</param>
/// <returns></returns>
public static IServiceCollection AddTailDropQueue(this IServiceCollection services, Action<TailDropOptions> configure)
{
services.Configure<TailDropOptions>(configure);
services.AddSingleton<IQueuePolicy, TailDrop>();
return services;
}
}
}

View File

@ -4,10 +4,11 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.RequestThrottling.Internal
namespace Microsoft.AspNetCore.RequestThrottling.Policies
{
internal class TailDrop : IRequestQueue
internal class TailDrop : IQueuePolicy, IDisposable
{
private readonly int _maxConcurrentRequests;
private readonly int _requestQueueLimit;
@ -16,14 +17,24 @@ namespace Microsoft.AspNetCore.RequestThrottling.Internal
private object _totalRequestsLock = new object();
public int TotalRequests { get; private set; }
public TailDrop(int maxConcurrentRequests, int requestQueueLimit)
public TailDrop(IOptions<TailDropOptions> options)
{
_maxConcurrentRequests = maxConcurrentRequests;
_requestQueueLimit = requestQueueLimit;
_maxConcurrentRequests = options.Value.MaxConcurrentRequests;
if (_maxConcurrentRequests <= 0)
{
throw new ArgumentException(nameof(_maxConcurrentRequests), "MaxConcurrentRequests must be a positive integer.");
}
_requestQueueLimit = options.Value.RequestQueueLimit;
if (_requestQueueLimit < 0)
{
throw new ArgumentException(nameof(_requestQueueLimit), "The RequestQueueLimit cannot be a negative number.");
}
_serverSemaphore = new SemaphoreSlim(_maxConcurrentRequests);
}
public async Task<bool> TryEnterQueueAsync()
public async Task<bool> TryEnterAsync()
{
// a return value of 'false' indicates that the request is rejected
// a return value of 'true' indicates that the request may proceed
@ -44,7 +55,7 @@ namespace Microsoft.AspNetCore.RequestThrottling.Internal
return true;
}
public void Release()
public void OnExit()
{
_serverSemaphore.Release();

View File

@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.AspNetCore.RequestThrottling.Policies
{
/// <summary>
/// Specifies options for the <see cref="TailDrop"/>
/// </summary>
public class TailDropOptions
{
/// <summary>
/// Maximum number of concurrent requests. Any extras will be queued on the server.
/// This option is highly application dependant, and must be configured by the application.
/// </summary>
public int MaxConcurrentRequests { get; set; }
/// <summary>
/// Maximum number of queued requests before the server starts rejecting connections with '503 Service Unavailible'.
/// Defaults to 5000 queued requests.
/// </summary>
public int RequestQueueLimit { get; set; } = 5000;
}
}

View File

@ -2,10 +2,10 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using System.Xml.Schema;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.RequestThrottling.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
@ -16,51 +16,31 @@ namespace Microsoft.AspNetCore.RequestThrottling
/// </summary>
public class RequestThrottlingMiddleware
{
private readonly IRequestQueue _requestQueue;
private readonly IQueuePolicy _queuePolicy;
private readonly RequestDelegate _next;
private readonly RequestThrottlingOptions _requestThrottlingOptions;
private readonly RequestDelegate _onRejected;
private readonly ILogger _logger;
private int _queuedRequests;
/// <summary>
/// Creates a new <see cref="RequestThrottlingMiddleware"/>.
/// </summary>
/// <param name="next">The <see cref="RequestDelegate"/> representing the next middleware in the pipeline.</param>
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> used for logging.</param>
/// <param name="options">The <see cref="RequestThrottlingOptions"/> containing the initialization parameters.</param>
public RequestThrottlingMiddleware(RequestDelegate next, ILoggerFactory loggerFactory, IOptions<RequestThrottlingOptions> options)
/// <param name="queue">The queueing strategy to use for the server.</param>
/// <param name="options">The options for the middleware, currently containing the 'OnRejected' callback.</param>
public RequestThrottlingMiddleware(RequestDelegate next, ILoggerFactory loggerFactory, IQueuePolicy queue, IOptions<RequestThrottlingOptions> options)
{
_requestThrottlingOptions = options.Value;
if (_requestThrottlingOptions.MaxConcurrentRequests == null)
{
throw new ArgumentException("The value of 'options.MaxConcurrentRequests' must be specified.", nameof(options));
}
if (_requestThrottlingOptions.MaxConcurrentRequests <= 0)
{
throw new ArgumentOutOfRangeException(nameof(options), "The value of `options.MaxConcurrentRequests` must be a positive integer.");
}
if (_requestThrottlingOptions.RequestQueueLimit < 0)
{
throw new ArgumentException("The value of 'options.RequestQueueLimit' must be a positive integer.", nameof(options));
}
if (_requestThrottlingOptions.OnRejected == null)
if (options.Value.OnRejected == null)
{
throw new ArgumentException("The value of 'options.OnRejected' must not be null.", nameof(options));
}
_next = next;
_logger = loggerFactory.CreateLogger<RequestThrottlingMiddleware>();
if (_requestThrottlingOptions.ServerAlwaysBlocks)
{
// note: this option for testing only. Blocks all requests from entering the server.
_requestQueue = new TailDrop(0, _requestThrottlingOptions.RequestQueueLimit);
}
else
{
_requestQueue = new TailDrop(_requestThrottlingOptions.MaxConcurrentRequests.Value, _requestThrottlingOptions.RequestQueueLimit);
}
_onRejected = options.Value.OnRejected;
_queuePolicy = queue;
}
/// <summary>
@ -70,48 +50,45 @@ namespace Microsoft.AspNetCore.RequestThrottling
/// <returns>A <see cref="Task"/> that completes when the request leaves.</returns>
public async Task Invoke(HttpContext context)
{
var waitInQueueTask = _requestQueue.TryEnterQueueAsync();
if (waitInQueueTask.IsCompletedSuccessfully && waitInQueueTask.Result)
{
RequestThrottlingLog.RequestRunImmediately(_logger, ActiveRequestCount);
}
else
{
RequestThrottlingLog.RequestEnqueued(_logger, ActiveRequestCount);
await waitInQueueTask;
RequestThrottlingLog.RequestDequeued(_logger, ActiveRequestCount);
}
if (!waitInQueueTask.Result)
{
RequestThrottlingLog.RequestRejectedQueueFull(_logger);
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
await _requestThrottlingOptions.OnRejected(context);
return;
}
Interlocked.Increment(ref _queuedRequests);
var success = false;
try
{
await _next(context);
success = await _queuePolicy.TryEnterAsync();
}
finally
{
_requestQueue.Release();
Interlocked.Decrement(ref _queuedRequests);
}
if (success)
{
try
{
await _next(context);
}
finally
{
_queuePolicy.OnExit();
}
}
else
{
RequestThrottlingLog.RequestRejectedQueueFull(_logger);
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
await _onRejected(context);
}
}
/// <summary>
/// The number of requests currently on the server.
/// Cannot exceeed the sum of <see cref="RequestThrottlingOptions.RequestQueueLimit"> and </see>/><see cref="RequestThrottlingOptions.MaxConcurrentRequests"/>.
/// The total number of requests waiting within the middleware
/// </summary>
public int ActiveRequestCount
public int QueuedRequestCount
{
get => _requestQueue.TotalRequests;
get => _queuedRequests;
}
// TODO :: update log wording to reflect the changes
private static class RequestThrottlingLog
{
private static readonly Action<ILogger, int, Exception> _requestEnqueued =

View File

@ -3,7 +3,6 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.RequestThrottling;
namespace Microsoft.AspNetCore.RequestThrottling
{
@ -12,19 +11,6 @@ namespace Microsoft.AspNetCore.RequestThrottling
/// </summary>
public class RequestThrottlingOptions
{
/// <summary>
/// Maximum number of concurrent requests. Any extras will be queued on the server.
/// This is null by default because the correct value is application specific. This option must be configured by the application.
/// </summary>
public int? MaxConcurrentRequests { get; set; }
/// <summary>
/// Maximum number of queued requests before the server starts rejecting connections with '503 Service Unavailible'.
/// Setting this value to 0 will disable the queue; all requests will either immediately enter the server or be rejected.
/// Defaults to 5000 queued requests.
/// </summary>
public int RequestQueueLimit { get; set; } = 5000;
/// <summary>
/// A <see cref="RequestDelegate"/> that handles requests rejected by this middleware.
/// If it doesn't modify the response, an empty 503 response will be written.
@ -33,10 +19,5 @@ namespace Microsoft.AspNetCore.RequestThrottling
{
return Task.CompletedTask;
};
/// <summary>
/// For internal testing only. If true, no requests will enter the server.
/// </summary>
internal bool ServerAlwaysBlocks { get; set; } = false;
}
}

View File

@ -4,7 +4,6 @@
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Internal;
using Xunit;
namespace Microsoft.AspNetCore.RequestThrottling.Tests
@ -12,117 +11,30 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests
public class MiddlewareTests
{
[Fact]
public async Task RequestsCanEnterIfSpaceAvailible()
public async Task RequestsCallNextIfQueueReturnsTrue()
{
var middleware = TestUtils.CreateTestMiddleware(maxConcurrentRequests: 1);
var context = new DefaultHttpContext();
// a request should go through with no problems
await middleware.Invoke(context).OrTimeout();
}
[Fact]
public async Task SemaphoreStatePreservedIfRequestsError()
{
var middleware = TestUtils.CreateTestMiddleware(
maxConcurrentRequests: 1,
next: httpContext =>
{
throw new DivideByZeroException();
});
Assert.Equal(0, middleware.ActiveRequestCount);
await Assert.ThrowsAsync<DivideByZeroException>(() => middleware.Invoke(new DefaultHttpContext())).OrTimeout();
Assert.Equal(0, middleware.ActiveRequestCount);
}
[Fact]
public async Task QueuedRequestsContinueWhenSpaceBecomesAvailible()
{
var blocker = new SyncPoint();
var firstRequest = true;
var flag = false;
var middleware = TestUtils.CreateTestMiddleware(
maxConcurrentRequests: 1,
next: httpContext =>
{
if (firstRequest)
{
firstRequest = false;
return blocker.WaitToContinue();
}
queue: TestStrategy.AlwaysPass,
next: (context) => {
flag = true;
return Task.CompletedTask;
});
// t1 (as the first request) is blocked by the tcs blocker
var t1 = middleware.Invoke(new DefaultHttpContext());
Assert.Equal(1, middleware.ActiveRequestCount);
// t2 is blocked from entering the server since t1 already exists there
// note: increasing MaxConcurrentRequests would allow t2 through while t1 is blocked
var t2 = middleware.Invoke(new DefaultHttpContext());
Assert.Equal(2, middleware.ActiveRequestCount);
// unblock the first task, and the second should follow
blocker.Continue();
await t1.OrTimeout();
await t2.OrTimeout();
await middleware.Invoke(new DefaultHttpContext());
Assert.True(flag);
}
[Fact]
public void InvalidArgumentIfMaxConcurrentRequestsIsNull()
{
var ex = Assert.Throws<ArgumentException>(() =>
{
TestUtils.CreateTestMiddleware(maxConcurrentRequests: null);
});
Assert.Equal("options", ex.ParamName);
}
[Fact]
public async void RequestsBlockedIfQueueFull()
{
var middleware = TestUtils.CreateBlockingTestMiddleware(
requestQueueLimit: 0,
next: httpContext =>
{
// throttle should bounce the request; it should never get here
throw new NotImplementedException();
});
await middleware.Invoke(new DefaultHttpContext()).OrTimeout();
}
[Fact]
public async void FullQueueResultsIn503Error()
{
var middleware = TestUtils.CreateBlockingTestMiddleware(requestQueueLimit: 0);
var context = new DefaultHttpContext();
await middleware.Invoke(context).OrTimeout();
Assert.Equal(503, context.Response.StatusCode);
}
[Fact]
public void MultipleRequestsFillUpQueue()
public async Task RequestRejectsIfQueueReturnsFalse()
{
var middleware = TestUtils.CreateTestMiddleware(
maxConcurrentRequests: 1,
requestQueueLimit: 10,
next: httpContext =>
{
return Task.Delay(TimeSpan.FromSeconds(30));
});
queue: TestStrategy.AlwaysReject);
Assert.Equal(0, middleware.ActiveRequestCount);
var _ = middleware.Invoke(new DefaultHttpContext());
Assert.Equal(1, middleware.ActiveRequestCount);
_ = middleware.Invoke(new DefaultHttpContext());
Assert.Equal(2, middleware.ActiveRequestCount);
var context = new DefaultHttpContext();
await middleware.Invoke(context);
Assert.Equal(StatusCodes.Status503ServiceUnavailable, context.Response.StatusCode);
}
[Fact]
@ -130,8 +42,8 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests
{
bool onRejectedInvoked = false;
var middleware = TestUtils.CreateBlockingTestMiddleware(
requestQueueLimit: 0,
var middleware = TestUtils.CreateTestMiddleware(
queue: TestStrategy.AlwaysReject,
onRejected: httpContext =>
{
onRejectedInvoked = true;
@ -144,14 +56,62 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests
Assert.Equal(StatusCodes.Status503ServiceUnavailable, context.Response.StatusCode);
}
[Fact]
public async void RequestsBlockedIfQueueFull()
{
var middleware = TestUtils.CreateTestMiddleware(
queue: TestStrategy.AlwaysReject,
next: httpContext =>
{
// throttle should bounce the request; it should never get here
throw new NotImplementedException();
});
await middleware.Invoke(new DefaultHttpContext()).OrTimeout();
}
[Fact]
public void IncomingRequestsFillUpQueue()
{
var middleware = TestUtils.CreateTestMiddleware(
queue: TestStrategy.AlwaysBlock);
Assert.Equal(0, middleware.QueuedRequestCount);
_ = middleware.Invoke(new DefaultHttpContext());
Assert.Equal(1, middleware.QueuedRequestCount);
_ = middleware.Invoke(new DefaultHttpContext());
Assert.Equal(2, middleware.QueuedRequestCount);
}
[Fact]
public async Task CleanupHappensEvenIfNextErrors()
{
var flag = false;
var middleware = TestUtils.CreateTestMiddleware(
queue: new TestStrategy(
invoke: (() => true),
onExit: () => { flag = true; }),
next: httpContext =>
{
throw new DivideByZeroException();
});
Assert.Equal(0, middleware.QueuedRequestCount);
await Assert.ThrowsAsync<DivideByZeroException>(() => middleware.Invoke(new DefaultHttpContext())).OrTimeout();
Assert.Equal(0, middleware.QueuedRequestCount);
Assert.True(flag);
}
[Fact]
public async void ExceptionThrownDuringOnRejected()
{
TaskCompletionSource<bool> tsc = new TaskCompletionSource<bool>();
var middleware = TestUtils.CreateTestMiddleware(
maxConcurrentRequests: 1,
requestQueueLimit: 0,
onRejected: httpContext =>
{
throw new DivideByZeroException();
@ -175,7 +135,7 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests
Assert.True(thirdRequest.IsCompletedSuccessfully);
Assert.Equal(0, middleware.ActiveRequestCount);
Assert.Equal(0, middleware.QueuedRequestCount);
}
}
}

View File

@ -2,62 +2,50 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading.Tasks;
using Microsoft.AspNetCore.RequestThrottling.Internal;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Internal;
using Xunit;
namespace Microsoft.AspNetCore.RequestThrottling.Tests
{
public class RequestQueueTests
public class TailDropTests
{
[Fact]
public async Task LimitsIncomingRequests()
{
using var s = TestUtils.CreateRequestQueue(1);
Assert.Equal(0, s.TotalRequests);
Assert.True(await s.TryEnterQueueAsync().OrTimeout());
Assert.Equal(1, s.TotalRequests);
s.Release();
Assert.Equal(0, s.TotalRequests);
}
[Fact]
public void DoesNotWaitIfSpaceAvailible()
{
using var s = TestUtils.CreateRequestQueue(2);
using var s = TestUtils.CreateTailDropQueue(2);
var t1 = s.TryEnterQueueAsync();
var t1 = s.TryEnterAsync();
Assert.True(t1.IsCompleted);
var t2 = s.TryEnterQueueAsync();
var t2 = s.TryEnterAsync();
Assert.True(t2.IsCompleted);
var t3 = s.TryEnterQueueAsync();
var t3 = s.TryEnterAsync();
Assert.False(t3.IsCompleted);
}
[Fact]
public async Task WaitsIfNoSpaceAvailible()
{
using var s = TestUtils.CreateRequestQueue(1);
Assert.True(await s.TryEnterQueueAsync().OrTimeout());
using var s = TestUtils.CreateTailDropQueue(1);
Assert.True(await s.TryEnterAsync().OrTimeout());
var waitingTask = s.TryEnterQueueAsync();
var waitingTask = s.TryEnterAsync();
Assert.False(waitingTask.IsCompleted);
s.Release();
s.OnExit();
Assert.True(await waitingTask.OrTimeout());
}
[Fact]
public async Task IsEncapsulated()
{
using var s1 = TestUtils.CreateRequestQueue(1);
using var s2 = TestUtils.CreateRequestQueue(1);
using var s1 = TestUtils.CreateTailDropQueue(1);
using var s2 = TestUtils.CreateTailDropQueue(1);
Assert.True(await s1.TryEnterQueueAsync().OrTimeout());
Assert.True(await s2.TryEnterQueueAsync().OrTimeout());
Assert.True(await s1.TryEnterAsync().OrTimeout());
Assert.True(await s2.TryEnterAsync().OrTimeout());
}
}
}

View File

@ -1,53 +1,96 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.RequestThrottling.Policies;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Microsoft.AspNetCore.RequestThrottling.Internal;
namespace Microsoft.AspNetCore.RequestThrottling.Tests
{
public static class TestUtils
{
public static RequestThrottlingMiddleware CreateTestMiddleware(int? maxConcurrentRequests, int requestQueueLimit = 5000, RequestDelegate onRejected = null, RequestDelegate next = null)
public static RequestThrottlingMiddleware CreateTestMiddleware(IQueuePolicy queue = null, RequestDelegate onRejected = null, RequestDelegate next = null)
{
var options = new RequestThrottlingOptions
var options = Options.Create(new RequestThrottlingOptions
{
MaxConcurrentRequests = maxConcurrentRequests,
RequestQueueLimit = requestQueueLimit
};
return BuildFromOptions(options, onRejected, next);
}
public static RequestThrottlingMiddleware CreateBlockingTestMiddleware(int requestQueueLimit = 5000, RequestDelegate onRejected = null, RequestDelegate next = null)
{
var options = new RequestThrottlingOptions
{
MaxConcurrentRequests = 999,
RequestQueueLimit = requestQueueLimit,
ServerAlwaysBlocks = true
};
return BuildFromOptions(options, onRejected, next);
}
private static RequestThrottlingMiddleware BuildFromOptions(RequestThrottlingOptions options, RequestDelegate onRejected, RequestDelegate next)
{
if (onRejected != null)
{
options.OnRejected = onRejected;
}
OnRejected = onRejected ?? (context => Task.CompletedTask),
});
return new RequestThrottlingMiddleware(
next: next ?? (context => Task.CompletedTask),
loggerFactory: NullLoggerFactory.Instance,
options: Options.Create(options)
queue: queue ?? CreateTailDropQueue(1, 0),
options: options
);
}
internal static IRequestQueue CreateRequestQueue(int maxConcurrentRequests) => new TailDrop(maxConcurrentRequests, 5000);
public static RequestThrottlingMiddleware CreateTestMiddleware_TailDrop(int maxConcurrentRequests, int requestQueueLimit, RequestDelegate onRejected = null, RequestDelegate next = null)
{
return CreateTestMiddleware(
queue: CreateTailDropQueue(maxConcurrentRequests, requestQueueLimit),
onRejected: onRejected,
next: next
);
}
internal static TailDrop CreateTailDropQueue(int maxConcurrentRequests, int requestQueueLimit = 5000)
{
var options = Options.Create(new TailDropOptions
{
MaxConcurrentRequests = maxConcurrentRequests,
RequestQueueLimit = requestQueueLimit
});
return new TailDrop(options);
}
}
public class TestStrategy : IQueuePolicy
{
private Func<Task<bool>> _invoke { get; }
private Action _onExit { get; }
public TestStrategy(Func<Task<bool>> invoke, Action onExit = null)
{
_invoke = invoke;
_onExit = onExit ?? (() => { });
}
public TestStrategy(Func<bool> invoke, Action onExit = null)
: this(async () =>
{
await Task.CompletedTask;
return invoke();
},
onExit)
{ }
public async Task<bool> TryEnterAsync()
{
await Task.CompletedTask;
return await _invoke();
}
public void OnExit()
{
_onExit();
}
public static TestStrategy AlwaysReject =
new TestStrategy(() => false);
public static TestStrategy AlwaysPass =
new TestStrategy(() => true);
public static TestStrategy AlwaysBlock =
new TestStrategy(async () =>
{
await new SemaphoreSlim(0).WaitAsync();
return false;
});
}
}