diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Internal/Interfaces/IResponseCache.cs b/src/Microsoft.AspNetCore.ResponseCaching/Internal/Interfaces/IResponseCache.cs index cd3b9da23f..41c85b277a 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/Internal/Interfaces/IResponseCache.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/Internal/Interfaces/IResponseCache.cs @@ -8,7 +8,10 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal { public interface IResponseCache { + IResponseCacheEntry Get(string key); Task GetAsync(string key); + + void Set(string key, IResponseCacheEntry entry, TimeSpan validFor); Task SetAsync(string key, IResponseCacheEntry entry, TimeSpan validFor); } } diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Internal/MemoryResponseCache.cs b/src/Microsoft.AspNetCore.ResponseCaching/Internal/MemoryResponseCache.cs index 4023345682..42c613bc2c 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/Internal/MemoryResponseCache.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/Internal/MemoryResponseCache.cs @@ -4,7 +4,7 @@ using System; using System.Threading.Tasks; using Microsoft.Extensions.Caching.Memory; -using Microsoft.Net.Http.Headers; +using Microsoft.Extensions.Internal; namespace Microsoft.AspNetCore.ResponseCaching.Internal { @@ -22,34 +22,39 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal _cache = cache; } - public Task GetAsync(string key) + public IResponseCacheEntry Get(string key) { var entry = _cache.Get(key); var memoryCachedResponse = entry as MemoryCachedResponse; if (memoryCachedResponse != null) { - return Task.FromResult(new CachedResponse + return new CachedResponse { Created = memoryCachedResponse.Created, StatusCode = memoryCachedResponse.StatusCode, Headers = memoryCachedResponse.Headers, Body = new SegmentReadStream(memoryCachedResponse.BodySegments, memoryCachedResponse.BodyLength) - }); + }; } else { - return Task.FromResult(entry as IResponseCacheEntry); + return entry as IResponseCacheEntry; } } - public async Task SetAsync(string key, IResponseCacheEntry entry, TimeSpan validFor) + public Task GetAsync(string key) + { + return Task.FromResult(Get(key)); + } + + public void Set(string key, IResponseCacheEntry entry, TimeSpan validFor) { var cachedResponse = entry as CachedResponse; if (cachedResponse != null) { var segmentStream = new SegmentWriteStream(StreamUtilities.BodySegmentSize); - await cachedResponse.Body.CopyToAsync(segmentStream); + cachedResponse.Body.CopyTo(segmentStream); _cache.Set( key, @@ -77,5 +82,11 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal }); } } + + public Task SetAsync(string key, IResponseCacheEntry entry, TimeSpan validFor) + { + Set(key, entry, validFor); + return TaskCache.CompletedTask; + } } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Microsoft.AspNetCore.ResponseCaching.csproj b/src/Microsoft.AspNetCore.ResponseCaching/Microsoft.AspNetCore.ResponseCaching.csproj index aa20232fa8..b56169a56a 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/Microsoft.AspNetCore.ResponseCaching.csproj +++ b/src/Microsoft.AspNetCore.ResponseCaching/Microsoft.AspNetCore.ResponseCaching.csproj @@ -17,7 +17,7 @@ - + diff --git a/src/Microsoft.AspNetCore.ResponseCaching/ResponseCachingMiddleware.cs b/src/Microsoft.AspNetCore.ResponseCaching/ResponseCachingMiddleware.cs index 798679eebc..cdb6cb817f 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/ResponseCachingMiddleware.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/ResponseCachingMiddleware.cs @@ -25,7 +25,6 @@ namespace Microsoft.AspNetCore.ResponseCaching private readonly IResponseCachingPolicyProvider _policyProvider; private readonly IResponseCache _cache; private readonly IResponseCachingKeyProvider _keyProvider; - private readonly Func _onStartingCallback; public ResponseCachingMiddleware( RequestDelegate next, @@ -66,7 +65,6 @@ namespace Microsoft.AspNetCore.ResponseCaching _policyProvider = policyProvider; _cache = cache; _keyProvider = keyProvider; - _onStartingCallback = state => OnResponseStartingAsync((ResponseCachingContext)state); } public async Task Invoke(HttpContext httpContext) @@ -90,13 +88,10 @@ namespace Microsoft.AspNetCore.ResponseCaching try { - // Subscribe to OnStarting event - httpContext.Response.OnStarting(_onStartingCallback, context); - await _next(httpContext); // If there was no response body, check the response headers now. We can cache things like redirects. - await OnResponseStartingAsync(context); + await StartResponseAsync(context); // Finalize the cache entry await FinalizeCacheBodyAsync(context); @@ -219,10 +214,17 @@ namespace Microsoft.AspNetCore.ResponseCaching return false; } - internal async Task FinalizeCacheHeadersAsync(ResponseCachingContext context) + + /// + /// Finalize cache headers. + /// + /// + /// true if a vary by entry needs to be stored in the cache; otherwise false. + private bool OnFinalizeCacheHeaders(ResponseCachingContext context) { if (_policyProvider.IsResponseCacheable(context)) { + var storeVaryByEntry = false; context.ShouldCacheResponse = true; // Create the cache entry now @@ -262,7 +264,7 @@ namespace Microsoft.AspNetCore.ResponseCaching // Always overwrite the CachedVaryByRules to update the expiry information _logger.LogVaryByRulesUpdated(normalizedVaryHeaders, normalizedVaryQueryKeys); - await _cache.SetAsync(context.BaseKey, context.CachedVaryByRules, context.CachedResponseValidFor); + storeVaryByEntry = true; context.StorageVaryKey = _keyProvider.CreateStorageVaryByKey(context); } @@ -290,13 +292,31 @@ namespace Microsoft.AspNetCore.ResponseCaching context.CachedResponse.Headers[header.Key] = header.Value; } } + + return storeVaryByEntry; } - else + + context.ResponseCachingStream.DisableBuffering(); + return false; + } + + internal void FinalizeCacheHeaders(ResponseCachingContext context) + { + if (OnFinalizeCacheHeaders(context)) { - context.ResponseCachingStream.DisableBuffering(); + _cache.Set(context.BaseKey, context.CachedVaryByRules, context.CachedResponseValidFor); } } + internal Task FinalizeCacheHeadersAsync(ResponseCachingContext context) + { + if (OnFinalizeCacheHeaders(context)) + { + return _cache.SetAsync(context.BaseKey, context.CachedVaryByRules, context.CachedResponseValidFor); + } + return TaskCache.CompletedTask; + } + internal async Task FinalizeCacheBodyAsync(ResponseCachingContext context) { if (context.ShouldCacheResponse && context.ResponseCachingStream.BufferingEnabled) @@ -327,19 +347,38 @@ namespace Microsoft.AspNetCore.ResponseCaching } } - internal Task OnResponseStartingAsync(ResponseCachingContext context) + /// + /// Mark the response as started and set the response time if no reponse was started yet. + /// + /// + /// true if the response was not started before this call; otherwise false. + private bool OnStartResponse(ResponseCachingContext context) { if (!context.ResponseStarted) { context.ResponseStarted = true; context.ResponseTime = _options.SystemClock.UtcNow; + return true; + } + return false; + } + + internal void StartResponse(ResponseCachingContext context) + { + if (OnStartResponse(context)) + { + FinalizeCacheHeaders(context); + } + } + + internal Task StartResponseAsync(ResponseCachingContext context) + { + if (OnStartResponse(context)) + { return FinalizeCacheHeadersAsync(context); } - else - { - return TaskCache.CompletedTask; - } + return TaskCache.CompletedTask; } internal static void AddResponseCachingFeature(HttpContext context) @@ -355,7 +394,12 @@ namespace Microsoft.AspNetCore.ResponseCaching { // Shim response stream context.OriginalResponseStream = context.HttpContext.Response.Body; - context.ResponseCachingStream = new ResponseCachingStream(context.OriginalResponseStream, _options.MaximumBodySize, StreamUtilities.BodySegmentSize); + context.ResponseCachingStream = new ResponseCachingStream( + context.OriginalResponseStream, + _options.MaximumBodySize, + StreamUtilities.BodySegmentSize, + () => StartResponse(context), + () => StartResponseAsync(context)); context.HttpContext.Response.Body = context.ResponseCachingStream; // Shim IHttpSendFileFeature diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Streams/ResponseCachingStream.cs b/src/Microsoft.AspNetCore.ResponseCaching/Streams/ResponseCachingStream.cs index 6644063b88..aa5fc371eb 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/Streams/ResponseCachingStream.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/Streams/ResponseCachingStream.cs @@ -14,12 +14,16 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal private readonly long _maxBufferSize; private readonly int _segmentSize; private SegmentWriteStream _segmentWriteStream; + private Action _startResponseCallback; + private Func _startResponseCallbackAsync; - internal ResponseCachingStream(Stream innerStream, long maxBufferSize, int segmentSize) + internal ResponseCachingStream(Stream innerStream, long maxBufferSize, int segmentSize, Action startResponseCallback, Func startResponseCallbackAsync) { _innerStream = innerStream; _maxBufferSize = maxBufferSize; _segmentSize = segmentSize; + _startResponseCallback = startResponseCallback; + _startResponseCallbackAsync = startResponseCallbackAsync; _segmentWriteStream = new SegmentWriteStream(_segmentSize); } @@ -71,10 +75,32 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal } public override void Flush() - => _innerStream.Flush(); + { + try + { + _startResponseCallback(); + _innerStream.Flush(); + } + catch + { + DisableBuffering(); + throw; + } + } - public override Task FlushAsync(CancellationToken cancellationToken) - => _innerStream.FlushAsync(); + public override async Task FlushAsync(CancellationToken cancellationToken) + { + try + { + await _startResponseCallbackAsync(); + await _innerStream.FlushAsync(); + } + catch + { + DisableBuffering(); + throw; + } + } // Underlying stream is write-only, no need to override other read related methods public override int Read(byte[] buffer, int offset, int count) @@ -84,6 +110,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal { try { + _startResponseCallback(); _innerStream.Write(buffer, offset, count); } catch @@ -109,6 +136,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal { try { + await _startResponseCallbackAsync(); await _innerStream.WriteAsync(buffer, offset, count, cancellationToken); } catch diff --git a/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCachingMiddlewareTests.cs b/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCachingMiddlewareTests.cs index 3cda27ef28..dfd9c12be9 100644 --- a/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCachingMiddlewareTests.cs +++ b/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCachingMiddlewareTests.cs @@ -359,7 +359,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests } [Fact] - public async Task OnResponseStartingAsync_IfAllowResponseCaptureIsTrue_SetsResponseTime() + public async Task StartResponsegAsync_IfAllowResponseCaptureIsTrue_SetsResponseTime() { var clock = new TestClock { @@ -372,13 +372,13 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests var context = TestUtils.CreateTestContext(); context.ResponseTime = null; - await middleware.OnResponseStartingAsync(context); + await middleware.StartResponseAsync(context); Assert.Equal(clock.UtcNow, context.ResponseTime); } [Fact] - public async Task OnResponseStartingAsync_IfAllowResponseCaptureIsTrue_SetsResponseTimeOnlyOnce() + public async Task StartResponseAsync_IfAllowResponseCaptureIsTrue_SetsResponseTimeOnlyOnce() { var clock = new TestClock { @@ -392,12 +392,12 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests var initialTime = clock.UtcNow; context.ResponseTime = null; - await middleware.OnResponseStartingAsync(context); + await middleware.StartResponseAsync(context); Assert.Equal(initialTime, context.ResponseTime); clock.UtcNow += TimeSpan.FromSeconds(10); - await middleware.OnResponseStartingAsync(context); + await middleware.StartResponseAsync(context); Assert.NotEqual(clock.UtcNow, context.ResponseTime); Assert.Equal(initialTime, context.ResponseTime); } @@ -790,10 +790,9 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests var middleware = TestUtils.CreateTestMiddleware(testSink: sink, cache: cache); var context = TestUtils.CreateTestContext(); - context.ShouldCacheResponse = false; middleware.ShimResponseStream(context); await context.HttpContext.Response.WriteAsync(new string('0', 10)); - + context.ShouldCacheResponse = false; await middleware.FinalizeCacheBodyAsync(context); diff --git a/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCachingTests.cs b/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCachingTests.cs index e8347bf11a..25fc1360e8 100644 --- a/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCachingTests.cs +++ b/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCachingTests.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Net; using System.Net.Http; using System.Threading; using System.Threading.Tasks; @@ -233,11 +232,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesCachedContent_IfVaryHeader_Matches() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Response.Headers[HeaderNames.Vary] = HeaderNames.From; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Response.Headers[HeaderNames.Vary] = HeaderNames.From); foreach (var builder in builders) { @@ -256,11 +251,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesFreshContent_IfVaryHeader_Mismatches() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Response.Headers[HeaderNames.Vary] = HeaderNames.From; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Response.Headers[HeaderNames.Vary] = HeaderNames.From); foreach (var builder in builders) { @@ -280,11 +271,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesCachedContent_IfVaryQueryKeys_Matches() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Features.Get().VaryByQueryKeys = new[] { "query" }; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Features.Get().VaryByQueryKeys = new[] { "query" }); foreach (var builder in builders) { @@ -302,11 +289,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesCachedContent_IfVaryQueryKeysExplicit_Matches_QueryKeyCaseInsensitive() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Features.Get().VaryByQueryKeys = new[] { "QueryA", "queryb" }; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Features.Get().VaryByQueryKeys = new[] { "QueryA", "queryb" }); foreach (var builder in builders) { @@ -324,11 +307,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesCachedContent_IfVaryQueryKeyStar_Matches_QueryKeyCaseInsensitive() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Features.Get().VaryByQueryKeys = new[] { "*" }; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Features.Get().VaryByQueryKeys = new[] { "*" }); foreach (var builder in builders) { @@ -346,11 +325,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesCachedContent_IfVaryQueryKeyExplicit_Matches_OrderInsensitive() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Features.Get().VaryByQueryKeys = new[] { "QueryB", "QueryA" }; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Features.Get().VaryByQueryKeys = new[] { "QueryB", "QueryA" }); foreach (var builder in builders) { @@ -368,11 +343,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesCachedContent_IfVaryQueryKeyStar_Matches_OrderInsensitive() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Features.Get().VaryByQueryKeys = new[] { "*" }; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Features.Get().VaryByQueryKeys = new[] { "*" }); foreach (var builder in builders) { @@ -390,11 +361,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesFreshContent_IfVaryQueryKey_Mismatches() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Features.Get().VaryByQueryKeys = new[] { "query" }; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Features.Get().VaryByQueryKeys = new[] { "query" }); foreach (var builder in builders) { @@ -412,11 +379,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesFreshContent_IfVaryQueryKeyExplicit_Mismatch_QueryKeyCaseSensitive() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Features.Get().VaryByQueryKeys = new[] { "QueryA", "QueryB" }; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Features.Get().VaryByQueryKeys = new[] { "QueryA", "QueryB" }); foreach (var builder in builders) { @@ -434,11 +397,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesFreshContent_IfVaryQueryKeyStar_Mismatch_QueryKeyValueCaseSensitive() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Features.Get().VaryByQueryKeys = new[] { "*" }; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Features.Get().VaryByQueryKeys = new[] { "*" }); foreach (var builder in builders) { @@ -501,11 +460,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesFreshContent_IfSetCookie_IsSpecified() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - var headers = context.Response.Headers[HeaderNames.SetCookie] = "cookieName=cookieValue"; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Response.Headers[HeaderNames.SetCookie] = "cookieName=cookieValue"); foreach (var builder in builders) { @@ -557,11 +512,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests await next.Invoke(); }); }, - requestDelegate: async (context) => - { - await context.Features.Get().SendFileAsync("dummy", 0, 0, CancellationToken.None); - await TestUtils.TestRequestDelegate(context); - }); + contextAction: async context => await context.Features.Get().SendFileAsync("dummy", 0, 0, CancellationToken.None)); foreach (var builder in builders) { @@ -623,14 +574,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesFreshContent_IfInitialResponseContainsNoStore() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - var headers = context.Response.GetTypedHeaders().CacheControl = new CacheControlHeaderValue() - { - NoStore = true - }; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Response.Headers[HeaderNames.CacheControl] = CacheControlHeaderValue.NoStoreString); foreach (var builder in builders) { @@ -687,11 +631,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void Serves304_IfIfNoneMatch_Satisfied() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - var headers = context.Response.GetTypedHeaders().ETag = new EntityTagHeaderValue("\"E1\""); - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Response.GetTypedHeaders().ETag = new EntityTagHeaderValue("\"E1\"")); foreach (var builder in builders) { @@ -711,11 +651,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesCachedContent_IfIfNoneMatch_NotSatisfied() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - var headers = context.Response.GetTypedHeaders().ETag = new EntityTagHeaderValue("\"E1\""); - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Response.GetTypedHeaders().ETag = new EntityTagHeaderValue("\"E1\"")); foreach (var builder in builders) { @@ -797,11 +733,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesCachedContent_WithoutReplacingCachedVaryBy_OnCacheMiss() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Response.Headers[HeaderNames.Vary] = HeaderNames.From; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Response.Headers[HeaderNames.Vary] = HeaderNames.From); foreach (var builder in builders) { @@ -823,11 +755,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesFreshContent_IfCachedVaryByUpdated_OnCacheMiss() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Response.Headers[HeaderNames.Vary] = context.Request.Headers[HeaderNames.Pragma]; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Response.Headers[HeaderNames.Vary] = context.Request.Headers[HeaderNames.Pragma]); foreach (var builder in builders) { @@ -858,11 +786,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesCachedContent_IfCachedVaryByNotUpdated_OnCacheMiss() { - var builders = TestUtils.CreateBuildersWithResponseCaching(requestDelegate: async (context) => - { - context.Response.Headers[HeaderNames.Vary] = context.Request.Headers[HeaderNames.Pragma]; - await TestUtils.TestRequestDelegate(context); - }); + var builders = TestUtils.CreateBuildersWithResponseCaching(contextAction: context => context.Response.Headers[HeaderNames.Vary] = context.Request.Headers[HeaderNames.Pragma]); foreach (var builder in builders) { diff --git a/test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs b/test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs index 88d1511375..fa18b927fc 100644 --- a/test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs +++ b/test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Net.Http; +using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; @@ -32,7 +33,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests StreamUtilities.BodySegmentSize = 10; } - internal static RequestDelegate TestRequestDelegate = async context => + private static bool TestRequestDelegate(HttpContext context, string guid) { var headers = context.Response.GetTypedHeaders(); @@ -42,7 +43,6 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests headers.Expires = DateTimeOffset.Now.AddSeconds(int.Parse(expires)); } - var uniqueId = Guid.NewGuid().ToString(); if (headers.CacheControl == null) { headers.CacheControl = new CacheControlHeaderValue @@ -57,13 +57,33 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests headers.CacheControl.MaxAge = string.IsNullOrEmpty(expires) ? TimeSpan.FromSeconds(10) : (TimeSpan?)null; } headers.Date = DateTimeOffset.UtcNow; - headers.Headers["X-Value"] = uniqueId; + headers.Headers["X-Value"] = guid; if (context.Request.Method != "HEAD") + { + return true; + } + return false; + } + + internal static async Task TestRequestDelegateWriteAsync(HttpContext context) + { + var uniqueId = Guid.NewGuid().ToString(); + if (TestRequestDelegate(context, uniqueId)) { await context.Response.WriteAsync(uniqueId); } - }; + } + + internal static Task TestRequestDelegateWrite(HttpContext context) + { + var uniqueId = Guid.NewGuid().ToString(); + if (TestRequestDelegate(context, uniqueId)) + { + context.Response.Write(uniqueId); + } + return TaskCache.CompletedTask; + } internal static IResponseCachingKeyProvider CreateTestKeyProvider() { @@ -78,37 +98,64 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests internal static IEnumerable CreateBuildersWithResponseCaching( Action configureDelegate = null, ResponseCachingOptions options = null, - RequestDelegate requestDelegate = null) + Action contextAction = null) + { + return CreateBuildersWithResponseCaching(configureDelegate, options, new RequestDelegate[] + { + context => + { + contextAction?.Invoke(context); + return TestRequestDelegateWrite(context); + }, + context => + { + contextAction?.Invoke(context); + return TestRequestDelegateWriteAsync(context); + }, + }); + } + + private static IEnumerable CreateBuildersWithResponseCaching( + Action configureDelegate = null, + ResponseCachingOptions options = null, + IEnumerable requestDelegates = null) { if (configureDelegate == null) { configureDelegate = app => { }; } - if (requestDelegate == null) + if (requestDelegates == null) { - requestDelegate = TestRequestDelegate; + requestDelegates = new RequestDelegate[] + { + TestRequestDelegateWriteAsync, + TestRequestDelegateWrite + }; } - // Test with in memory ResponseCache - yield return new WebHostBuilder() - .ConfigureServices(services => - { - services.AddResponseCaching(responseCachingOptions => + foreach (var requestDelegate in requestDelegates) + { + // Test with in memory ResponseCache + yield return new WebHostBuilder() + .ConfigureServices(services => { - if (options != null) + services.AddResponseCaching(responseCachingOptions => { - responseCachingOptions.MaximumBodySize = options.MaximumBodySize; - responseCachingOptions.UseCaseSensitivePaths = options.UseCaseSensitivePaths; - responseCachingOptions.SystemClock = options.SystemClock; - } + if (options != null) + { + responseCachingOptions.MaximumBodySize = options.MaximumBodySize; + responseCachingOptions.UseCaseSensitivePaths = options.UseCaseSensitivePaths; + responseCachingOptions.SystemClock = options.SystemClock; + } + }); + }) + .Configure(app => + { + configureDelegate(app); + app.UseResponseCaching(); + app.Run(requestDelegate); }); - }) - .Configure(app => - { - configureDelegate(app); - app.UseResponseCaching(); - app.Run(requestDelegate); - }); + } } internal static ResponseCachingMiddleware CreateTestMiddleware( @@ -181,6 +228,25 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests } } + internal static class HttpResponseWritingExtensions + { + internal static void Write(this HttpResponse response, string text) + { + if (response == null) + { + throw new ArgumentNullException(nameof(response)); + } + + if (text == null) + { + throw new ArgumentNullException(nameof(text)); + } + + byte[] data = Encoding.UTF8.GetBytes(text); + response.Body.Write(data, 0, data.Length); + } + } + internal class LoggedMessage { internal static LoggedMessage RequestMethodNotCacheable => new LoggedMessage(1, LogLevel.Debug); @@ -289,23 +355,33 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests public int GetCount { get; private set; } public int SetCount { get; private set; } - public Task GetAsync(string key) + public IResponseCacheEntry Get(string key) { GetCount++; try { - return Task.FromResult(_storage[key]); + return _storage[key]; } catch { - return Task.FromResult(null); + return null; } } + public Task GetAsync(string key) + { + return Task.FromResult(Get(key)); + } + + public void Set(string key, IResponseCacheEntry entry, TimeSpan validFor) + { + SetCount++; + _storage[key] = entry; + } + public Task SetAsync(string key, IResponseCacheEntry entry, TimeSpan validFor) { - SetCount++; - _storage[key] = entry; + Set(key, entry, validFor); return TaskCache.CompletedTask; } }