From 1dff8cbdbcd4bd84b6eab4d28371aa8cead09540 Mon Sep 17 00:00:00 2001 From: Kahbazi Date: Fri, 7 Jun 2019 22:45:48 +0430 Subject: [PATCH] Add OnRejected to Request Throttling middleware (#10817) * Add OnRejected to Request Throttling middleware * Remove IApplicationBuilder extension with option * Add test * Add summary for OnRejected --- ...NetCore.RequestThrottling.netcoreapp3.0.cs | 1 + .../src/RequestThrottlingMiddleware.cs | 17 ++++-- .../src/RequestThrottlingOptions.cs | 11 ++++ .../RequestThrottling/test/MiddlewareTests.cs | 60 ++++++++++++++++++- .../RequestThrottling/test/TestUtils.cs | 7 ++- 5 files changed, 88 insertions(+), 8 deletions(-) diff --git a/src/Middleware/RequestThrottling/ref/Microsoft.AspNetCore.RequestThrottling.netcoreapp3.0.cs b/src/Middleware/RequestThrottling/ref/Microsoft.AspNetCore.RequestThrottling.netcoreapp3.0.cs index ff97f9feaf..ae4e1ab1a7 100644 --- a/src/Middleware/RequestThrottling/ref/Microsoft.AspNetCore.RequestThrottling.netcoreapp3.0.cs +++ b/src/Middleware/RequestThrottling/ref/Microsoft.AspNetCore.RequestThrottling.netcoreapp3.0.cs @@ -20,6 +20,7 @@ namespace Microsoft.AspNetCore.RequestThrottling { public RequestThrottlingOptions() { } public int? MaxConcurrentRequests { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public Microsoft.AspNetCore.Http.RequestDelegate OnRejected { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public int RequestQueueLimit { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } } } diff --git a/src/Middleware/RequestThrottling/src/RequestThrottlingMiddleware.cs b/src/Middleware/RequestThrottling/src/RequestThrottlingMiddleware.cs index d96dee1862..4aa7e9c9be 100644 --- a/src/Middleware/RequestThrottling/src/RequestThrottlingMiddleware.cs +++ b/src/Middleware/RequestThrottling/src/RequestThrottlingMiddleware.cs @@ -18,6 +18,7 @@ namespace Microsoft.AspNetCore.RequestThrottling { private readonly RequestQueue _requestQueue; private readonly RequestDelegate _next; + private readonly RequestThrottlingOptions _requestThrottlingOptions; private readonly ILogger _logger; /// @@ -28,20 +29,27 @@ namespace Microsoft.AspNetCore.RequestThrottling /// The containing the initialization parameters. public RequestThrottlingMiddleware(RequestDelegate next, ILoggerFactory loggerFactory, IOptions options) { - if (options.Value.MaxConcurrentRequests == null) + _requestThrottlingOptions = options.Value; + + if (_requestThrottlingOptions.MaxConcurrentRequests == null) { throw new ArgumentException("The value of 'options.MaxConcurrentRequests' must be specified.", nameof(options)); } - if (options.Value.RequestQueueLimit < 0) + if (_requestThrottlingOptions.RequestQueueLimit < 0) { throw new ArgumentException("The value of 'options.RequestQueueLimit' must be a positive integer.", nameof(options)); } + if (_requestThrottlingOptions.OnRejected == null) + { + throw new ArgumentException("The value of 'options.OnRejected' must not be null.", nameof(options)); + } + _next = next; _logger = loggerFactory.CreateLogger(); _requestQueue = new RequestQueue( - options.Value.MaxConcurrentRequests.Value, - options.Value.RequestQueueLimit); + _requestThrottlingOptions.MaxConcurrentRequests.Value, + _requestThrottlingOptions.RequestQueueLimit); } /// @@ -56,6 +64,7 @@ namespace Microsoft.AspNetCore.RequestThrottling { RequestThrottlingLog.RequestRejectedQueueFull(_logger); context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable; + await _requestThrottlingOptions.OnRejected(context); return; } else if (!waitInQueueTask.IsCompletedSuccessfully) diff --git a/src/Middleware/RequestThrottling/src/RequestThrottlingOptions.cs b/src/Middleware/RequestThrottling/src/RequestThrottlingOptions.cs index 03b52d4502..c0e03ae2a1 100644 --- a/src/Middleware/RequestThrottling/src/RequestThrottlingOptions.cs +++ b/src/Middleware/RequestThrottling/src/RequestThrottlingOptions.cs @@ -1,6 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.RequestThrottling; namespace Microsoft.AspNetCore.RequestThrottling @@ -22,5 +24,14 @@ namespace Microsoft.AspNetCore.RequestThrottling /// Defaults to 5000 queued requests. /// public int RequestQueueLimit { get; set; } = 5000; + + /// + /// A that handles requests rejected by this middleware. + /// If it doesn't modify the response, an empty 503 response will be written. + /// + public RequestDelegate OnRejected { get; set; } = context => + { + return Task.CompletedTask; + }; } } diff --git a/src/Middleware/RequestThrottling/test/MiddlewareTests.cs b/src/Middleware/RequestThrottling/test/MiddlewareTests.cs index d5fa7b141c..aed8d8a963 100644 --- a/src/Middleware/RequestThrottling/test/MiddlewareTests.cs +++ b/src/Middleware/RequestThrottling/test/MiddlewareTests.cs @@ -33,7 +33,7 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests Assert.Equal(0, middleware.ActiveRequestCount); - await Assert.ThrowsAsync(() => middleware.Invoke(new DefaultHttpContext())); + await Assert.ThrowsAsync(() => middleware.Invoke(new DefaultHttpContext())).OrTimeout(); Assert.Equal(0, middleware.ActiveRequestCount); } @@ -93,7 +93,7 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests throw new NotImplementedException(); }); - await middleware.Invoke(new DefaultHttpContext()); + await middleware.Invoke(new DefaultHttpContext()).OrTimeout(); } [Fact] @@ -104,7 +104,7 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests requestQueueLimit: 0); var context = new DefaultHttpContext(); - await middleware.Invoke(context); + await middleware.Invoke(context).OrTimeout(); Assert.Equal(503, context.Response.StatusCode); } @@ -127,5 +127,59 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests _ = middleware.Invoke(new DefaultHttpContext()); Assert.Equal(2, middleware.ActiveRequestCount); } + + [Fact] + public async void FullQueueInvokesOnRejected() + { + bool onRejectedInvoked = false; + + var middleware = TestUtils.CreateTestMiddleware( + maxConcurrentRequests: 0, + requestQueueLimit: 0, + onRejected: httpContext => + { + onRejectedInvoked = true; + return Task.CompletedTask; + }); + + var context = new DefaultHttpContext(); + await middleware.Invoke(context).OrTimeout(); + Assert.True(onRejectedInvoked); + Assert.Equal(StatusCodes.Status503ServiceUnavailable, context.Response.StatusCode); + } + + [Fact] + public async void ExceptionThrownDuringOnRejected() + { + TaskCompletionSource tsc = new TaskCompletionSource(); + + var middleware = TestUtils.CreateTestMiddleware( + maxConcurrentRequests: 1, + requestQueueLimit: 0, + onRejected: httpContext => + { + throw new DivideByZeroException(); + }, + next: httpContext => + { + return tsc.Task; + }); + + var firstRequest = middleware.Invoke(new DefaultHttpContext()); + + var context = new DefaultHttpContext(); + await Assert.ThrowsAsync(() => middleware.Invoke(context)).OrTimeout(); + Assert.Equal(StatusCodes.Status503ServiceUnavailable, context.Response.StatusCode); + + tsc.SetResult(true); + + Assert.True(firstRequest.IsCompletedSuccessfully); + + var thirdRequest = middleware.Invoke(new DefaultHttpContext()); + + Assert.True(thirdRequest.IsCompletedSuccessfully); + + Assert.Equal(0, middleware.ActiveRequestCount); + } } } diff --git a/src/Middleware/RequestThrottling/test/TestUtils.cs b/src/Middleware/RequestThrottling/test/TestUtils.cs index 28e4c9fb14..9b5380f0ac 100644 --- a/src/Middleware/RequestThrottling/test/TestUtils.cs +++ b/src/Middleware/RequestThrottling/test/TestUtils.cs @@ -12,7 +12,7 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests { public static class TestUtils { - public static RequestThrottlingMiddleware CreateTestMiddleware(int? maxConcurrentRequests, int requestQueueLimit = 5000, RequestDelegate next = null) + public static RequestThrottlingMiddleware CreateTestMiddleware(int? maxConcurrentRequests, int requestQueueLimit = 5000, RequestDelegate onRejected = null, RequestDelegate next = null) { var options = new RequestThrottlingOptions { @@ -20,6 +20,11 @@ namespace Microsoft.AspNetCore.RequestThrottling.Tests RequestQueueLimit = requestQueueLimit }; + if (onRejected != null) + { + options.OnRejected = onRejected; + } + return new RequestThrottlingMiddleware( next: next ?? (context => Task.CompletedTask), loggerFactory: NullLoggerFactory.Instance,