diff --git a/src/Middleware/ResponseCaching/src/ResponseCachingMiddleware.cs b/src/Middleware/ResponseCaching/src/ResponseCachingMiddleware.cs index 9bce6b921d..c64d7c27d8 100644 --- a/src/Middleware/ResponseCaching/src/ResponseCachingMiddleware.cs +++ b/src/Middleware/ResponseCaching/src/ResponseCachingMiddleware.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Features; using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Logging; using Microsoft.Extensions.ObjectPool; @@ -345,7 +344,9 @@ namespace Microsoft.AspNetCore.ResponseCaching { var contentLength = context.HttpContext.Response.ContentLength; var bufferStream = context.ResponseCachingStream.GetBufferStream(); - if (!contentLength.HasValue || contentLength == bufferStream.Length) + if (!contentLength.HasValue || contentLength == bufferStream.Length + || (bufferStream.Length == 0 + && string.Equals(context.HttpContext.Request.Method, "HEAD", StringComparison.OrdinalIgnoreCase))) { var response = context.HttpContext.Response; // Add a content-length if required diff --git a/src/Middleware/ResponseCaching/test/ResponseCachingMiddlewareTests.cs b/src/Middleware/ResponseCaching/test/ResponseCachingMiddlewareTests.cs index a0332fb83b..68adab8c09 100644 --- a/src/Middleware/ResponseCaching/test/ResponseCachingMiddlewareTests.cs +++ b/src/Middleware/ResponseCaching/test/ResponseCachingMiddlewareTests.cs @@ -723,8 +723,10 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests LoggedMessage.ResponseCached); } - [Fact] - public async Task FinalizeCacheBody_DoNotCache_IfContentLengthMismatches() + [Theory] + [InlineData("GET")] + [InlineData("HEAD")] + public async Task FinalizeCacheBody_DoNotCache_IfContentLengthMismatches(string method) { var cache = new TestResponseCache(); var sink = new TestSink(); @@ -734,6 +736,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests context.ShouldCacheResponse = true; middleware.ShimResponseStream(context); context.HttpContext.Response.ContentLength = 9; + context.HttpContext.Request.Method = method; await context.HttpContext.Response.WriteAsync(new string('0', 10)); @@ -749,6 +752,39 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests LoggedMessage.ResponseContentLengthMismatchNotCached); } + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task FinalizeCacheBody_RequestHead_Cache_IfContentLengthPresent_AndBodyAbsentOrOfSameLength(bool includeBody) + { + var cache = new TestResponseCache(); + var sink = new TestSink(); + var middleware = TestUtils.CreateTestMiddleware(testSink: sink, cache: cache); + var context = TestUtils.CreateTestContext(); + + context.ShouldCacheResponse = true; + middleware.ShimResponseStream(context); + context.HttpContext.Response.ContentLength = 10; + context.HttpContext.Request.Method = "HEAD"; + + if (includeBody) + { + // A response to HEAD should not include a body, but it may be present + await context.HttpContext.Response.WriteAsync(new string('0', 10)); + } + + context.CachedResponse = new CachedResponse(); + context.BaseKey = "BaseKey"; + context.CachedResponseValidFor = TimeSpan.FromSeconds(10); + + middleware.FinalizeCacheBody(context); + + Assert.Equal(1, cache.SetCount); + TestUtils.AssertLoggedMessages( + sink.Writes, + LoggedMessage.ResponseCached); + } + [Fact] public async Task FinalizeCacheBody_Cache_IfContentLengthAbsent() { diff --git a/src/Middleware/ResponseCaching/test/ResponseCachingTests.cs b/src/Middleware/ResponseCaching/test/ResponseCachingTests.cs index 2a3bb3fde4..522b6952c2 100644 --- a/src/Middleware/ResponseCaching/test/ResponseCachingTests.cs +++ b/src/Middleware/ResponseCaching/test/ResponseCachingTests.cs @@ -2,13 +2,9 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO; using System.Net.Http; -using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.TestHost; using Microsoft.Net.Http.Headers; using Xunit; @@ -775,6 +771,24 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests } } + [Fact] + public async Task ServesCachedContent_IfAvailable_UsingHead_WithContentLength() + { + var builders = TestUtils.CreateBuildersWithResponseCaching(); + + foreach (var builder in builders) + { + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.SendAsync(TestUtils.CreateRequest("HEAD", "?contentLength=10")); + var subsequentResponse = await client.SendAsync(TestUtils.CreateRequest("HEAD", "?contentLength=10")); + + await AssertCachedResponseAsync(initialResponse, subsequentResponse); + } + } + } + private static void Assert304Headers(HttpResponseMessage initialResponse, HttpResponseMessage subsequentResponse) { // https://tools.ietf.org/html/rfc7232#section-4.1 diff --git a/src/Middleware/ResponseCaching/test/TestUtils.cs b/src/Middleware/ResponseCaching/test/TestUtils.cs index de54e95da6..270778f996 100644 --- a/src/Middleware/ResponseCaching/test/TestUtils.cs +++ b/src/Middleware/ResponseCaching/test/TestUtils.cs @@ -4,11 +4,9 @@ using System; using System.Collections.Generic; using System.IO; -using System.IO.Pipelines; using System.Linq; using System.Net.Http; using System.Text; -using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; @@ -23,7 +21,6 @@ using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; using Microsoft.Net.Http.Headers; using Xunit; -using ISystemClock = Microsoft.AspNetCore.ResponseCaching.ISystemClock; namespace Microsoft.AspNetCore.ResponseCaching.Tests { @@ -61,6 +58,12 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests headers.Date = DateTimeOffset.UtcNow; headers.Headers["X-Value"] = guid; + var contentLength = context.Request.Query["ContentLength"]; + if (!string.IsNullOrEmpty(contentLength)) + { + headers.ContentLength = long.Parse(contentLength); + } + if (context.Request.Method != "HEAD") { return true;