// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Globalization;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Headers;
using Microsoft.AspNetCore.ResponseCaching.Internal;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
using Microsoft.Net.Http.Headers;

namespace Microsoft.AspNetCore.ResponseCaching
{
    /// <summary>
    /// Middleware that serves eligible requests from an <see cref="IResponseCacheStore"/> and,
    /// when the policy allows it, captures the response produced by the rest of the pipeline
    /// and stores it for later requests.
    /// </summary>
    public class ResponseCacheMiddleware
    {
        // Lifetime used for a cached entry when the response specifies neither
        // s-maxage, max-age nor an Expires header.
        private static readonly TimeSpan DefaultExpirationTimeSpan = TimeSpan.FromSeconds(10);

        // RFC 7232 section 4.1: a server generating a 304 must send these headers if they
        // would have been sent in a 200 response to the same request.
        private static readonly string[] HeadersToIncludeIn304 = new[]
        {
            HeaderNames.CacheControl,
            HeaderNames.ContentLocation,
            HeaderNames.Date,
            HeaderNames.ETag,
            HeaderNames.Expires,
            HeaderNames.Vary
        };

        private readonly RequestDelegate _next;
        private readonly ResponseCacheOptions _options;
        private readonly ILogger _logger;
        private readonly IResponseCachePolicyProvider _policyProvider;
        private readonly IResponseCacheStore _store;
        private readonly IResponseCacheKeyProvider _keyProvider;

        // Cached delegate so a new closure is not allocated per request when subscribing
        // to HttpResponse.OnStarting.
        private readonly Func<object, Task> _onStartingCallback;

        /// <summary>
        /// Creates the middleware.
        /// </summary>
        /// <param name="next">The next delegate in the pipeline.</param>
        /// <param name="options">Response caching options (clock, maximum body size, ...).</param>
        /// <param name="loggerFactory">Factory used to create this middleware's logger.</param>
        /// <param name="policyProvider">Decides whether requests/responses are cacheable and entries fresh.</param>
        /// <param name="store">Backing store for cached responses and vary-by rules.</param>
        /// <param name="keyProvider">Computes base and vary-by cache keys.</param>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public ResponseCacheMiddleware(
            RequestDelegate next,
            IOptions<ResponseCacheOptions> options,
            ILoggerFactory loggerFactory,
            IResponseCachePolicyProvider policyProvider,
            IResponseCacheStore store,
            IResponseCacheKeyProvider keyProvider)
        {
            if (next == null)
            {
                throw new ArgumentNullException(nameof(next));
            }
            if (options == null)
            {
                throw new ArgumentNullException(nameof(options));
            }
            if (loggerFactory == null)
            {
                throw new ArgumentNullException(nameof(loggerFactory));
            }
            if (policyProvider == null)
            {
                throw new ArgumentNullException(nameof(policyProvider));
            }
            if (store == null)
            {
                throw new ArgumentNullException(nameof(store));
            }
            if (keyProvider == null)
            {
                throw new ArgumentNullException(nameof(keyProvider));
            }

            _next = next;
            _options = options.Value;
            _logger = loggerFactory.CreateLogger<ResponseCacheMiddleware>();
            _policyProvider = policyProvider;
            _store = store;
            _keyProvider = keyProvider;
            _onStartingCallback = state => OnResponseStartingAsync((ResponseCacheContext)state);
        }

        /// <summary>
        /// Processes a request. It either serves it from cache or runs the rest of the pipeline
        /// while buffering the response body so it can be cached afterwards.
        /// </summary>
        /// <param name="httpContext">The current request context.</param>
        public async Task Invoke(HttpContext httpContext)
        {
            var context = new ResponseCacheContext(httpContext, _logger);

            // Requests the policy rejects bypass the caching logic entirely.
            if (!_policyProvider.IsRequestCacheable(context))
            {
                await _next(httpContext);
                return;
            }

            // Can this request be served from cache?
            if (await TryServeFromCacheAsync(context))
            {
                return;
            }

            // Hook up to listen to the response stream
            ShimResponseStream(context);

            try
            {
                // Headers are finalized when the response starts (first body write / flush).
                httpContext.Response.OnStarting(_onStartingCallback, context);

                await _next(httpContext);

                // If there was no response body, OnStarting may not have fired yet; finalize
                // the headers now so responses such as redirects can still be cached.
                // OnResponseStartingAsync is guarded so it only runs once.
                await OnResponseStartingAsync(context);

                // Finalize the cache entry
                await FinalizeCacheBodyAsync(context);
            }
            finally
            {
                UnshimResponseStream(context);
            }
        }

        /// <summary>
        /// Attempts to answer the current request from <paramref name="cachedResponse"/>.
        /// </summary>
        /// <returns>
        /// <c>true</c> if the cached entry was fresh and a response (304 or full copy) was written;
        /// <c>false</c> if the entry is stale and the request must go to the pipeline.
        /// </returns>
        internal async Task<bool> TryServeCachedResponseAsync(ResponseCacheContext context, CachedResponse cachedResponse)
        {
            context.CachedResponse = cachedResponse;
            context.CachedResponseHeaders = new ResponseHeaders(cachedResponse.Headers);
            context.ResponseTime = _options.SystemClock.UtcNow;

            // Clamp to zero in case the clock moved backwards relative to the entry's creation time.
            var cachedEntryAge = context.ResponseTime.Value - context.CachedResponse.Created;
            context.CachedEntryAge = cachedEntryAge > TimeSpan.Zero ? cachedEntryAge : TimeSpan.Zero;

            if (!_policyProvider.IsCachedEntryFresh(context))
            {
                return false;
            }

            var response = context.HttpContext.Response;

            // Check conditional request rules
            if (ContentIsNotModified(context))
            {
                _logger.LogNotModifiedServed();
                response.StatusCode = StatusCodes.Status304NotModified;

                // Carry over the validator/caching headers a 200 would have had (RFC 7232 4.1).
                foreach (var key in HeadersToIncludeIn304)
                {
                    StringValues values;
                    if (context.CachedResponse.Headers.TryGetValue(key, out values))
                    {
                        response.Headers[key] = values;
                    }
                }
            }
            else
            {
                // Copy the cached status code and response headers. Indexer assignment is used
                // instead of Add so a header already set by earlier middleware is overwritten
                // rather than causing a duplicate-key exception.
                response.StatusCode = context.CachedResponse.StatusCode;
                foreach (var header in context.CachedResponse.Headers)
                {
                    response.Headers[header.Key] = header.Value;
                }

                response.Headers[HeaderNames.Age] = context.CachedEntryAge.Value.TotalSeconds.ToString("F0", CultureInfo.InvariantCulture);

                // Copy the cached response body
                var body = context.CachedResponse.Body;
                if (body.Length > 0)
                {
                    // Add a content-length if required
                    if (!response.ContentLength.HasValue && StringValues.IsNullOrEmpty(response.Headers[HeaderNames.TransferEncoding]))
                    {
                        response.ContentLength = body.Length;
                    }

                    try
                    {
                        await body.CopyToAsync(response.Body, StreamUtilities.BodySegmentSize, context.HttpContext.RequestAborted);
                    }
                    catch (OperationCanceledException)
                    {
                        // The client went away mid-copy; abort the connection rather than
                        // leaving a truncated response.
                        context.HttpContext.Abort();
                    }
                }

                _logger.LogCachedResponseServed();
            }

            return true;
        }

        /// <summary>
        /// Looks up the request in the cache store and serves it if a fresh entry exists.
        /// </summary>
        /// <returns>
        /// <c>true</c> if a response was written (from cache, or a 504 for "only-if-cached"
        /// misses); <c>false</c> if the request should continue down the pipeline.
        /// </returns>
        internal async Task<bool> TryServeFromCacheAsync(ResponseCacheContext context)
        {
            context.BaseKey = _keyProvider.CreateBaseKey(context);
            var cacheEntry = await _store.GetAsync(context.BaseKey);

            var varyByRules = cacheEntry as CachedVaryByRules;
            if (varyByRules != null)
            {
                // The base key maps to vary-by rules: recompute the candidate key(s) and try each.
                context.CachedVaryByRules = varyByRules;

                foreach (var varyKey in _keyProvider.CreateLookupVaryByKeys(context))
                {
                    var varyEntry = (await _store.GetAsync(varyKey)) as CachedResponse;
                    if (varyEntry != null && await TryServeCachedResponseAsync(context, varyEntry))
                    {
                        return true;
                    }
                }
            }
            else
            {
                var cachedResponse = cacheEntry as CachedResponse;
                if (cachedResponse != null && await TryServeCachedResponseAsync(context, cachedResponse))
                {
                    return true;
                }
            }

            // The client forbade contacting the origin; a miss must be answered with 504.
            if (context.RequestCacheControlHeaderValue.OnlyIfCached)
            {
                _logger.LogGatewayTimeoutServed();
                context.HttpContext.Response.StatusCode = StatusCodes.Status504GatewayTimeout;
                return true;
            }

            _logger.LogNoResponseServed();
            return false;
        }

        /// <summary>
        /// Decides, once the response headers are known, whether the response will be cached.
        /// If so, it stores/refreshes any vary-by rules and builds the header part of the
        /// <see cref="CachedResponse"/>. Otherwise it turns off body buffering.
        /// </summary>
        internal async Task FinalizeCacheHeadersAsync(ResponseCacheContext context)
        {
            if (!_policyProvider.IsResponseCacheable(context))
            {
                context.ResponseCacheStream.DisableBuffering();
                return;
            }

            context.ShouldCacheResponse = true;

            // Create the cache entry now
            var response = context.HttpContext.Response;
            var varyHeaders = new StringValues(response.Headers.GetCommaSeparatedValues(HeaderNames.Vary));
            var varyQueryKeys = context.HttpContext.GetResponseCacheFeature()?.VaryByQueryKeys ?? StringValues.Empty;

            // Lifetime precedence: s-maxage, then max-age, then Expires - response time, then default.
            context.CachedResponseValidFor = context.ResponseCacheControlHeaderValue.SharedMaxAge
                ?? context.ResponseCacheControlHeaderValue.MaxAge
                ?? (context.ResponseExpires - context.ResponseTime.Value)
                ?? DefaultExpirationTimeSpan;

            // Check if any vary rules exist
            if (!StringValues.IsNullOrEmpty(varyHeaders) || !StringValues.IsNullOrEmpty(varyQueryKeys))
            {
                // Normalize order and casing of vary by rules
                var normalizedVaryHeaders = GetOrderCasingNormalizedStringValues(varyHeaders);
                var normalizedVaryQueryKeys = GetOrderCasingNormalizedStringValues(varyQueryKeys);

                // Update vary rules if they are different; a new key prefix separates entries
                // stored under the previous rules.
                if (context.CachedVaryByRules == null ||
                    !StringValues.Equals(context.CachedVaryByRules.QueryKeys, normalizedVaryQueryKeys) ||
                    !StringValues.Equals(context.CachedVaryByRules.Headers, normalizedVaryHeaders))
                {
                    context.CachedVaryByRules = new CachedVaryByRules
                    {
                        VaryByKeyPrefix = FastGuid.NewGuid().IdString,
                        Headers = normalizedVaryHeaders,
                        QueryKeys = normalizedVaryQueryKeys
                    };
                }

                // Always overwrite the CachedVaryByRules to update the expiry information
                _logger.LogVaryByRulesUpdated(normalizedVaryHeaders, normalizedVaryQueryKeys);
                await _store.SetAsync(context.BaseKey, context.CachedVaryByRules, context.CachedResponseValidFor);

                context.StorageVaryKey = _keyProvider.CreateStorageVaryByKey(context);
            }

            // Ensure date header is set
            if (!context.ResponseDate.HasValue)
            {
                context.ResponseDate = context.ResponseTime.Value;
                // Setting the date on the raw response headers.
                context.TypedResponseHeaders.Date = context.ResponseDate;
            }

            // Store the response on the state
            context.CachedResponse = new CachedResponse
            {
                Created = context.ResponseDate.Value,
                StatusCode = context.HttpContext.Response.StatusCode
            };

            // Age is recomputed when the entry is served, so it is not stored.
            foreach (var header in context.TypedResponseHeaders.Headers)
            {
                if (!string.Equals(header.Key, HeaderNames.Age, StringComparison.OrdinalIgnoreCase))
                {
                    context.CachedResponse.Headers.Add(header);
                }
            }
        }

        /// <summary>
        /// Stores the buffered response body in the cache if the response was marked cacheable,
        /// buffering was not disabled, and the body length agrees with any Content-Length header.
        /// </summary>
        internal async Task FinalizeCacheBodyAsync(ResponseCacheContext context)
        {
            var contentLength = context.TypedResponseHeaders.ContentLength;
            if (context.ShouldCacheResponse && context.ResponseCacheStream.BufferingEnabled)
            {
                var bufferStream = context.ResponseCacheStream.GetBufferStream();

                // A length mismatch indicates a truncated or inconsistent body; don't cache it.
                if (!contentLength.HasValue || contentLength == bufferStream.Length)
                {
                    context.CachedResponse.Body = bufferStream;
                    _logger.LogResponseCached();
                    await _store.SetAsync(context.StorageVaryKey ?? context.BaseKey, context.CachedResponse, context.CachedResponseValidFor);
                }
                else
                {
                    _logger.LogResponseContentLengthMismatchNotCached();
                }
            }
            else
            {
                _logger.LogResponseNotCached();
            }
        }

        /// <summary>
        /// Records the response time and finalizes cache headers. It runs at most once per request.
        /// It is invoked from the OnStarting callback and again after the pipeline completes;
        /// the <see cref="ResponseCacheContext.ResponseStarted"/> flag makes the second call a no-op.
        /// </summary>
        internal Task OnResponseStartingAsync(ResponseCacheContext context)
        {
            if (context.ResponseStarted)
            {
                return TaskCache.CompletedTask;
            }

            context.ResponseStarted = true;
            context.ResponseTime = _options.SystemClock.UtcNow;

            return FinalizeCacheHeadersAsync(context);
        }

        /// <summary>
        /// Replaces the response body stream (and the send-file feature, if present) with
        /// wrappers that capture the response body. It also registers the response cache feature.
        /// </summary>
        internal void ShimResponseStream(ResponseCacheContext context)
        {
            // Shim response stream
            context.OriginalResponseStream = context.HttpContext.Response.Body;
            context.ResponseCacheStream = new ResponseCacheStream(context.OriginalResponseStream, _options.MaximumBodySize, StreamUtilities.BodySegmentSize);
            context.HttpContext.Response.Body = context.ResponseCacheStream;

            // Shim IHttpSendFileFeature
            context.OriginalSendFileFeature = context.HttpContext.Features.Get<IHttpSendFileFeature>();
            if (context.OriginalSendFileFeature != null)
            {
                context.HttpContext.Features.Set<IHttpSendFileFeature>(new SendFileFeatureWrapper(context.OriginalSendFileFeature, context.ResponseCacheStream));
            }

            context.HttpContext.AddResponseCacheFeature();
        }

        /// <summary>
        /// Restores the original response body stream and send-file feature, and removes the
        /// response cache feature. This undoes <see cref="ShimResponseStream"/>.
        /// </summary>
        internal static void UnshimResponseStream(ResponseCacheContext context)
        {
            // Unshim response stream
            context.HttpContext.Response.Body = context.OriginalResponseStream;

            // Unshim IHttpSendFileFeature
            context.HttpContext.Features.Set<IHttpSendFileFeature>(context.OriginalSendFileFeature);

            context.HttpContext.RemoveResponseCacheFeature();
        }

        /// <summary>
        /// Evaluates the request's conditional GET headers against the cached entry.
        /// If-None-Match is checked first (weak comparison, "*" matches anything). If it is
        /// absent, If-Modified-Since is compared with the cached Last-Modified (or Date) header.
        /// </summary>
        /// <returns><c>true</c> if a 304 Not Modified should be served.</returns>
        internal static bool ContentIsNotModified(ResponseCacheContext context)
        {
            var cachedResponseHeaders = context.CachedResponseHeaders;
            var ifNoneMatchHeader = context.TypedRequestHeaders.IfNoneMatch;

            // Treat an empty parsed list the same as an absent header so If-Modified-Since
            // is still evaluated.
            if (ifNoneMatchHeader != null && ifNoneMatchHeader.Count > 0)
            {
                if (ifNoneMatchHeader.Count == 1 && ifNoneMatchHeader[0].Equals(EntityTagHeaderValue.Any))
                {
                    context.Logger.LogNotModifiedIfNoneMatchStar();
                    return true;
                }

                if (cachedResponseHeaders.ETag != null)
                {
                    foreach (var tag in ifNoneMatchHeader)
                    {
                        if (cachedResponseHeaders.ETag.Compare(tag, useStrongComparison: false))
                        {
                            context.Logger.LogNotModifiedIfNoneMatchMatched(tag);
                            return true;
                        }
                    }
                }
            }
            else
            {
                // Conditional GET uses If-Modified-Since (RFC 7232 3.3). If-Unmodified-Since is a
                // precondition for state-changing requests and must never produce a 304.
                var ifModifiedSince = context.TypedRequestHeaders.IfModifiedSince;
                if (ifModifiedSince != null)
                {
                    var lastModified = cachedResponseHeaders.LastModified ?? cachedResponseHeaders.Date;
                    if (lastModified <= ifModifiedSince)
                    {
                        // NOTE(review): this logging helper is defined elsewhere and its name still
                        // refers to If-Unmodified-Since; consider renaming it (and its message) to match.
                        context.Logger.LogNotModifiedIfUnmodifiedSinceSatisfied(lastModified.Value, ifModifiedSince.Value);
                        return true;
                    }
                }
            }

            return false;
        }

        /// <summary>
        /// Upper-cases every value (invariant culture) and sorts the values ordinally.
        /// Equivalent vary-by rules therefore compare equal regardless of order and casing.
        /// </summary>
        internal static StringValues GetOrderCasingNormalizedStringValues(StringValues stringValues)
        {
            if (stringValues.Count == 1)
            {
                return new StringValues(stringValues.ToString().ToUpperInvariant());
            }

            var originalArray = stringValues.ToArray();
            var newArray = new string[originalArray.Length];

            for (var i = 0; i < originalArray.Length; i++)
            {
                newArray[i] = originalArray[i].ToUpperInvariant();
            }

            // Since the casing has already been normalized, use Ordinal comparison
            Array.Sort(newArray, StringComparer.Ordinal);

            return new StringValues(newArray);
        }
    }
}