Add OnRejected to Request Throttling middleware (#10817)

* Add OnRejected to Request Throttling middleware

* Remove IApplicationBuilder extension with option

* Add test

* Add summary for OnRejected
This commit is contained in:
Kahbazi 2019-06-07 22:45:48 +04:30 committed by Dylan Dmitri Gray
parent d058e0f495
commit 1dff8cbdbc
5 changed files with 88 additions and 8 deletions

View File

@ -20,6 +20,7 @@ namespace Microsoft.AspNetCore.RequestThrottling
{
public RequestThrottlingOptions() { }
public int? MaxConcurrentRequests { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public Microsoft.AspNetCore.Http.RequestDelegate OnRejected { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public int RequestQueueLimit { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
}
}

View File

@ -18,6 +18,7 @@ namespace Microsoft.AspNetCore.RequestThrottling
{
private readonly RequestQueue _requestQueue;
private readonly RequestDelegate _next;
private readonly RequestThrottlingOptions _requestThrottlingOptions;
private readonly ILogger _logger;
/// <summary>
@ -28,20 +29,27 @@ namespace Microsoft.AspNetCore.RequestThrottling
/// <param name="options">The <see cref="RequestThrottlingOptions"/> containing the initialization parameters.</param>
public RequestThrottlingMiddleware(RequestDelegate next, ILoggerFactory loggerFactory, IOptions<RequestThrottlingOptions> options)
{
if (options.Value.MaxConcurrentRequests == null)
_requestThrottlingOptions = options.Value;
if (_requestThrottlingOptions.MaxConcurrentRequests == null)
{
throw new ArgumentException("The value of 'options.MaxConcurrentRequests' must be specified.", nameof(options));
}
if (options.Value.RequestQueueLimit < 0)
if (_requestThrottlingOptions.RequestQueueLimit < 0)
{
throw new ArgumentException("The value of 'options.RequestQueueLimit' must be a positive integer.", nameof(options));
}
if (_requestThrottlingOptions.OnRejected == null)
{
throw new ArgumentException("The value of 'options.OnRejected' must not be null.", nameof(options));
}
_next = next;
_logger = loggerFactory.CreateLogger<RequestThrottlingMiddleware>();
_requestQueue = new RequestQueue(
options.Value.MaxConcurrentRequests.Value,
options.Value.RequestQueueLimit);
_requestThrottlingOptions.MaxConcurrentRequests.Value,
_requestThrottlingOptions.RequestQueueLimit);
}
/// <summary>
@ -56,6 +64,7 @@ namespace Microsoft.AspNetCore.RequestThrottling
{
RequestThrottlingLog.RequestRejectedQueueFull(_logger);
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
await _requestThrottlingOptions.OnRejected(context);
return;
}
else if (!waitInQueueTask.IsCompletedSuccessfully)

View File

@ -1,6 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.RequestThrottling;
namespace Microsoft.AspNetCore.RequestThrottling
@ -22,5 +24,14 @@ namespace Microsoft.AspNetCore.RequestThrottling
/// Defaults to 5000 queued requests.
/// </summary>
public int RequestQueueLimit { get; set; } = 5000;
/// <summary>
/// A <see cref="RequestDelegate"/> that handles requests rejected by this middleware.
/// If it doesn't modify the response, an empty 503 response will be written.
/// </summary>
public RequestDelegate OnRejected { get; set; } = context =>
{
return Task.CompletedTask;
};
}
}

View File

@ -33,7 +33,7 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests
Assert.Equal(0, middleware.ActiveRequestCount);
await Assert.ThrowsAsync<DivideByZeroException>(() => middleware.Invoke(new DefaultHttpContext()));
await Assert.ThrowsAsync<DivideByZeroException>(() => middleware.Invoke(new DefaultHttpContext())).OrTimeout();
Assert.Equal(0, middleware.ActiveRequestCount);
}
@ -93,7 +93,7 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests
throw new NotImplementedException();
});
await middleware.Invoke(new DefaultHttpContext());
await middleware.Invoke(new DefaultHttpContext()).OrTimeout();
}
[Fact]
@ -104,7 +104,7 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests
requestQueueLimit: 0);
var context = new DefaultHttpContext();
await middleware.Invoke(context);
await middleware.Invoke(context).OrTimeout();
Assert.Equal(503, context.Response.StatusCode);
}
@ -127,5 +127,59 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests
_ = middleware.Invoke(new DefaultHttpContext());
Assert.Equal(2, middleware.ActiveRequestCount);
}
[Fact]
public async void FullQueueInvokesOnRejected()
{
bool onRejectedInvoked = false;
var middleware = TestUtils.CreateTestMiddleware(
maxConcurrentRequests: 0,
requestQueueLimit: 0,
onRejected: httpContext =>
{
onRejectedInvoked = true;
return Task.CompletedTask;
});
var context = new DefaultHttpContext();
await middleware.Invoke(context).OrTimeout();
Assert.True(onRejectedInvoked);
Assert.Equal(StatusCodes.Status503ServiceUnavailable, context.Response.StatusCode);
}
[Fact]
public async void ExceptionThrownDuringOnRejected()
{
TaskCompletionSource<bool> tsc = new TaskCompletionSource<bool>();
var middleware = TestUtils.CreateTestMiddleware(
maxConcurrentRequests: 1,
requestQueueLimit: 0,
onRejected: httpContext =>
{
throw new DivideByZeroException();
},
next: httpContext =>
{
return tsc.Task;
});
var firstRequest = middleware.Invoke(new DefaultHttpContext());
var context = new DefaultHttpContext();
await Assert.ThrowsAsync<DivideByZeroException>(() => middleware.Invoke(context)).OrTimeout();
Assert.Equal(StatusCodes.Status503ServiceUnavailable, context.Response.StatusCode);
tsc.SetResult(true);
Assert.True(firstRequest.IsCompletedSuccessfully);
var thirdRequest = middleware.Invoke(new DefaultHttpContext());
Assert.True(thirdRequest.IsCompletedSuccessfully);
Assert.Equal(0, middleware.ActiveRequestCount);
}
}
}

View File

@ -12,7 +12,7 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests
{
public static class TestUtils
{
public static RequestThrottlingMiddleware CreateTestMiddleware(int? maxConcurrentRequests, int requestQueueLimit = 5000, RequestDelegate next = null)
public static RequestThrottlingMiddleware CreateTestMiddleware(int? maxConcurrentRequests, int requestQueueLimit = 5000, RequestDelegate onRejected = null, RequestDelegate next = null)
{
var options = new RequestThrottlingOptions
{
@ -20,6 +20,11 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests
RequestQueueLimit = requestQueueLimit
};
if (onRejected != null)
{
options.OnRejected = onRejected;
}
return new RequestThrottlingMiddleware(
next: next ?? (context => Task.CompletedTask),
loggerFactory: NullLoggerFactory.Instance,