409 lines
16 KiB
C#
409 lines
16 KiB
C#
// Copyright (c) .NET Foundation. All rights reserved.
|
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
|
|
|
using System;
|
|
using System.Globalization;
|
|
using System.Threading.Tasks;
|
|
using Microsoft.AspNetCore.Builder;
|
|
using Microsoft.AspNetCore.Http;
|
|
using Microsoft.AspNetCore.Http.Features;
|
|
using Microsoft.AspNetCore.Http.Headers;
|
|
using Microsoft.AspNetCore.ResponseCaching.Internal;
|
|
using Microsoft.Extensions.Internal;
|
|
using Microsoft.Extensions.Logging;
|
|
using Microsoft.Extensions.Options;
|
|
using Microsoft.Extensions.Primitives;
|
|
using Microsoft.Net.Http.Headers;
|
|
|
|
namespace Microsoft.AspNetCore.ResponseCaching
|
|
{
|
|
public class ResponseCacheMiddleware
|
|
{
|
|
// Fallback lifetime for a cache entry when the response supplies no s-maxage, max-age or Expires value.
private static readonly TimeSpan DefaultExpirationTimeSpan = TimeSpan.FromSeconds(10);

// Next component in the request pipeline.
private readonly RequestDelegate _next;

// Services and settings captured by the constructor.
private readonly ResponseCacheOptions _options;
private readonly ILogger _logger;
private readonly IResponseCachePolicyProvider _policyProvider;
private readonly IResponseCacheStore _store;
private readonly IResponseCacheKeyProvider _keyProvider;

// Delegate registered with HttpResponse.OnStarting; created once so it is not re-allocated per request.
private readonly Func<object, Task> _onStartingCallback;
|
|
|
|
/// <summary>
/// Creates the middleware and captures the collaborators it uses to look up, serve and store responses.
/// </summary>
/// <param name="next">The next delegate in the request pipeline.</param>
/// <param name="options">Wrapper whose <c>Value</c> supplies the response cache options.</param>
/// <param name="loggerFactory">Factory used to create this middleware's logger.</param>
/// <param name="policyProvider">Decides request/response cacheability and cached entry freshness.</param>
/// <param name="store">Backing store for cached entries.</param>
/// <param name="keyProvider">Produces base and vary-by cache keys.</param>
/// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
public ResponseCacheMiddleware(
    RequestDelegate next,
    IOptions<ResponseCacheOptions> options,
    ILoggerFactory loggerFactory,
    IResponseCachePolicyProvider policyProvider,
    IResponseCacheStore store,
    IResponseCacheKeyProvider keyProvider)
{
    // Arguments are validated in declaration order, so the first missing one is the one reported.
    if (next == null)
    {
        throw new ArgumentNullException(nameof(next));
    }

    if (options == null)
    {
        throw new ArgumentNullException(nameof(options));
    }

    if (loggerFactory == null)
    {
        throw new ArgumentNullException(nameof(loggerFactory));
    }

    if (policyProvider == null)
    {
        throw new ArgumentNullException(nameof(policyProvider));
    }

    if (store == null)
    {
        throw new ArgumentNullException(nameof(store));
    }

    if (keyProvider == null)
    {
        throw new ArgumentNullException(nameof(keyProvider));
    }

    // Pipeline and caching collaborators.
    _next = next;
    _policyProvider = policyProvider;
    _keyProvider = keyProvider;
    _store = store;

    // Derived state.
    _options = options.Value;
    _logger = loggerFactory.CreateLogger<ResponseCacheMiddleware>();

    // OnStarting hands its state back as object; unwrap it to the per-request cache context.
    _onStartingCallback = callbackState => OnResponseStartingAsync((ResponseCacheContext)callbackState);
}
|
|
|
|
/// <summary>
/// Handles one request: serves it from the cache when possible, otherwise runs the rest of the
/// pipeline while capturing the response so it can be stored.
/// </summary>
/// <param name="httpContext">The context of the current request.</param>
public async Task Invoke(HttpContext httpContext)
{
    var cacheContext = new ResponseCacheContext(httpContext, _logger);

    // Requests the policy rejects skip all caching logic.
    if (!_policyProvider.IsRequestCacheable(cacheContext))
    {
        await _next(httpContext);
        return;
    }

    // Nothing more to do when the cache lookup already produced the response.
    var servedFromCache = await TryServeFromCacheAsync(cacheContext);
    if (servedFromCache)
    {
        return;
    }

    // Swap in the capturing stream before the rest of the pipeline writes anything.
    ShimResponseStream(cacheContext);

    try
    {
        // Inspect the headers as soon as the response starts.
        httpContext.Response.OnStarting(_onStartingCallback, cacheContext);

        await _next(httpContext);

        // Covers responses that never started (no body written), e.g. redirects, which can still be cached.
        await OnResponseStartingAsync(cacheContext);

        // Store the buffered body, if the response turned out to be cacheable.
        await FinalizeCacheBodyAsync(cacheContext);
    }
    finally
    {
        // Always restore the original stream and features, even when the pipeline throws.
        UnshimResponseStream(cacheContext);
    }
}
|
|
|
|
/// <summary>
/// Attempts to answer the current request from <paramref name="cachedResponse"/>.
/// </summary>
/// <param name="context">
/// Per-request caching state. Its cached response, cached response headers, response time and
/// cached entry age are set here regardless of the outcome.
/// </param>
/// <param name="cachedResponse">The stored response found for this request.</param>
/// <returns>
/// <c>true</c> when the policy provider reports the entry as fresh and either a 304 status or the
/// cached response was produced; <c>false</c> when the entry is not fresh.
/// </returns>
internal async Task<bool> TryServeCachedResponseAsync(ResponseCacheContext context, CachedResponse cachedResponse)
{
    context.CachedResponse = cachedResponse;
    context.CachedResponseHeaders = new ResponseHeaders(cachedResponse.Headers);
    context.ResponseTime = _options.SystemClock.UtcNow;

    // Clamp the age at zero in case the entry's creation time is later than the current clock reading.
    var cachedEntryAge = context.ResponseTime.Value - context.CachedResponse.Created;
    context.CachedEntryAge = cachedEntryAge > TimeSpan.Zero ? cachedEntryAge : TimeSpan.Zero;

    if (_policyProvider.IsCachedEntryFresh(context))
    {
        // Check conditional request rules
        if (ContentIsNotModified(context))
        {
            _logger.LogNotModifiedServed();
            context.HttpContext.Response.StatusCode = StatusCodes.Status304NotModified;
        }
        else
        {
            var response = context.HttpContext.Response;

            // Copy the cached status code and response headers
            response.StatusCode = context.CachedResponse.StatusCode;
            foreach (var header in context.CachedResponse.Headers)
            {
                // Assign through the indexer instead of Add: the response may already contain this
                // header (set by the server or earlier middleware), and Add throws on a duplicate key.
                response.Headers[header.Key] = header.Value;
            }

            // Report how long the entry has been in the cache, in whole seconds.
            response.Headers[HeaderNames.Age] = context.CachedEntryAge.Value.TotalSeconds.ToString("F0", CultureInfo.InvariantCulture);

            // Copy the cached response body
            var body = context.CachedResponse.Body;
            if (body.Length > 0)
            {
                // Add a content-length if required
                if (!response.ContentLength.HasValue && StringValues.IsNullOrEmpty(response.Headers[HeaderNames.TransferEncoding]))
                {
                    response.ContentLength = body.Length;
                }

                try
                {
                    await body.CopyToAsync(response.Body, StreamUtilities.BodySegmentSize, context.HttpContext.RequestAborted);
                }
                catch (OperationCanceledException)
                {
                    // The copy was cancelled (the token passed is RequestAborted); abort the connection
                    // rather than surfacing the cancellation to the pipeline.
                    context.HttpContext.Abort();
                }
            }

            _logger.LogCachedResponseServed();
        }

        return true;
    }

    return false;
}
|
|
|
|
/// <summary>
/// Looks up the request in the store and serves a cached response when a usable one exists.
/// </summary>
/// <param name="context">Per-request caching state; its base key (and vary-by rules, if found) are set here.</param>
/// <returns>
/// <c>true</c> when a response was produced here (a cached response, or 504 for an
/// only-if-cached request with no usable entry); otherwise <c>false</c>.
/// </returns>
internal async Task<bool> TryServeFromCacheAsync(ResponseCacheContext context)
{
    context.BaseKey = _keyProvider.CreateBaseKey(context);
    var baseEntry = await _store.GetAsync(context.BaseKey);

    var varyByRules = baseEntry as CachedVaryByRules;
    if (varyByRules != null)
    {
        // The base key holds vary-by rules, so the response lives under a derived key; probe each candidate.
        context.CachedVaryByRules = varyByRules;

        foreach (var lookupKey in _keyProvider.CreateLookupVaryByKeys(context))
        {
            var candidate = (await _store.GetAsync(lookupKey)) as CachedResponse;
            if (candidate != null && await TryServeCachedResponseAsync(context, candidate))
            {
                return true;
            }
        }
    }
    else
    {
        // No vary-by rules: the base key may hold the response directly.
        var directResponse = baseEntry as CachedResponse;
        if (directResponse != null && await TryServeCachedResponseAsync(context, directResponse))
        {
            return true;
        }
    }

    // The client asked for cached content only and none could be served.
    if (context.RequestCacheControlHeaderValue.OnlyIfCached)
    {
        _logger.LogGatewayTimeoutServed();
        context.HttpContext.Response.StatusCode = StatusCodes.Status504GatewayTimeout;
        return true;
    }

    _logger.LogNoResponseServed();
    return false;
}
|
|
|
|
/// <summary>
/// Decides whether the response being produced should be cached and, if so, prepares the cache
/// entry from its status code and headers. Buffering is disabled when the response is not cacheable.
/// </summary>
/// <param name="context">Per-request caching state that is updated with the caching decision and the new entry.</param>
internal async Task FinalizeCacheHeadersAsync(ResponseCacheContext context)
{
    // Not cacheable: stop buffering the body and leave everything else untouched.
    if (!_policyProvider.IsResponseCacheable(context))
    {
        context.ResponseCacheStream.DisableBuffering();
        return;
    }

    context.ShouldCacheResponse = true;

    var httpResponse = context.HttpContext.Response;
    var varyHeaders = new StringValues(httpResponse.Headers.GetCommaSeparatedValues(HeaderNames.Vary));
    var varyQueryKeys = context.HttpContext.GetResponseCacheFeature()?.VaryByQueryKeys ?? StringValues.Empty;

    // Lifetime precedence: s-maxage, then max-age, then Expires relative to the response time, then the default.
    var cacheControl = context.ResponseCacheControlHeaderValue;
    context.CachedResponseValidFor =
        cacheControl.SharedMaxAge ??
        cacheControl.MaxAge ??
        (context.ResponseExpires - context.ResponseTime.Value) ??
        DefaultExpirationTimeSpan;

    var hasVaryRules = !StringValues.IsNullOrEmpty(varyHeaders) || !StringValues.IsNullOrEmpty(varyQueryKeys);
    if (hasVaryRules)
    {
        // Normalize order and casing so equivalent rule sets compare equal.
        var normalizedVaryHeaders = GetOrderCasingNormalizedStringValues(varyHeaders);
        var normalizedVaryQueryKeys = GetOrderCasingNormalizedStringValues(varyQueryKeys);

        // Replace the rules (with a new key prefix) only when none exist yet or they differ.
        var existingRules = context.CachedVaryByRules;
        var rulesUnchanged =
            existingRules != null &&
            StringValues.Equals(existingRules.QueryKeys, normalizedVaryQueryKeys) &&
            StringValues.Equals(existingRules.Headers, normalizedVaryHeaders);

        if (!rulesUnchanged)
        {
            context.CachedVaryByRules = new CachedVaryByRules
            {
                VaryByKeyPrefix = FastGuid.NewGuid().IdString,
                Headers = normalizedVaryHeaders,
                QueryKeys = normalizedVaryQueryKeys
            };
        }

        // The rules are written back even when unchanged so their expiry is refreshed.
        _logger.LogVaryByRulesUpdated(normalizedVaryHeaders, normalizedVaryQueryKeys);
        await _store.SetAsync(context.BaseKey, context.CachedVaryByRules, context.CachedResponseValidFor);

        context.StorageVaryKey = _keyProvider.CreateStorageVaryByKey(context);
    }

    // Make sure the response has a date; fall back to the recorded response time.
    if (!context.ResponseDate.HasValue)
    {
        context.ResponseDate = context.ResponseTime.Value;
        // Also write it to the response headers themselves.
        context.TypedResponseHeaders.Date = context.ResponseDate;
    }

    // Start the cache entry; the body is attached later by FinalizeCacheBodyAsync.
    context.CachedResponse = new CachedResponse
    {
        Created = context.ResponseDate.Value,
        StatusCode = httpResponse.StatusCode
    };

    // Copy every response header except Age into the entry.
    foreach (var header in context.TypedResponseHeaders.Headers)
    {
        if (string.Equals(header.Key, HeaderNames.Age, StringComparison.OrdinalIgnoreCase))
        {
            continue;
        }

        context.CachedResponse.Headers.Add(header);
    }
}
|
|
|
|
/// <summary>
/// Attaches the buffered response body to the pending cache entry and writes the entry to the store,
/// provided the response was marked cacheable, buffering is still enabled, and any declared
/// Content-Length matches the number of bytes buffered.
/// </summary>
/// <param name="context">Per-request caching state holding the pending entry and the buffering stream.</param>
internal async Task FinalizeCacheBodyAsync(ResponseCacheContext context)
{
    var declaredLength = context.TypedResponseHeaders.ContentLength;

    var canStore = context.ShouldCacheResponse && context.ResponseCacheStream.BufferingEnabled;
    if (!canStore)
    {
        _logger.LogResponseNotCached();
        return;
    }

    var bufferedBody = context.ResponseCacheStream.GetBufferStream();

    // A declared length that disagrees with what was buffered means the body is incomplete or inconsistent.
    if (declaredLength.HasValue && declaredLength.Value != bufferedBody.Length)
    {
        _logger.LogResponseContentLengthMismatchNotCached();
        return;
    }

    context.CachedResponse.Body = bufferedBody;
    _logger.LogResponseCached();

    // Entries with vary-by rules are stored under the vary key; all others under the base key.
    await _store.SetAsync(context.StorageVaryKey ?? context.BaseKey, context.CachedResponse, context.CachedResponseValidFor);
}
|
|
|
|
/// <summary>
/// Runs the header finalization step the first time it is called for a request; later calls do nothing.
/// </summary>
/// <param name="context">Per-request caching state; marked as started and stamped with the response time on first call.</param>
/// <returns>The header finalization task on the first call, otherwise an already-completed task.</returns>
internal Task OnResponseStartingAsync(ResponseCacheContext context)
{
    // Invoked both from the OnStarting callback and directly after the pipeline; only the first call counts.
    if (context.ResponseStarted)
    {
        return TaskCache.CompletedTask;
    }

    context.ResponseStarted = true;
    context.ResponseTime = _options.SystemClock.UtcNow;

    return FinalizeCacheHeadersAsync(context);
}
|
|
|
|
/// <summary>
/// Replaces the response body stream (and the send-file feature, when present) with wrappers that
/// route output through a <c>ResponseCacheStream</c>, and adds the response cache feature.
/// The originals are remembered on <paramref name="context"/> so they can be restored later.
/// </summary>
/// <param name="context">Per-request caching state that receives the original and wrapping objects.</param>
internal void ShimResponseStream(ResponseCacheContext context)
{
    var httpContext = context.HttpContext;

    // Wrap the response body.
    context.OriginalResponseStream = httpContext.Response.Body;
    context.ResponseCacheStream = new ResponseCacheStream(
        context.OriginalResponseStream,
        _options.MaximumBodySize,
        StreamUtilities.BodySegmentSize);
    httpContext.Response.Body = context.ResponseCacheStream;

    // Wrap IHttpSendFileFeature, but only when the server provides one.
    context.OriginalSendFileFeature = httpContext.Features.Get<IHttpSendFileFeature>();
    if (context.OriginalSendFileFeature != null)
    {
        var sendFileWrapper = new SendFileFeatureWrapper(context.OriginalSendFileFeature, context.ResponseCacheStream);
        httpContext.Features.Set<IHttpSendFileFeature>(sendFileWrapper);
    }

    httpContext.AddResponseCacheFeature();
}
|
|
|
|
/// <summary>
/// Restores the response body stream and send-file feature recorded on <paramref name="context"/>
/// and removes the response cache feature.
/// </summary>
/// <param name="context">Per-request caching state holding the original stream and feature.</param>
internal static void UnshimResponseStream(ResponseCacheContext context)
{
    var httpContext = context.HttpContext;

    // Put the original body stream back.
    httpContext.Response.Body = context.OriginalResponseStream;

    // Put the original send-file feature back (whatever value was recorded, including none).
    httpContext.Features.Set(context.OriginalSendFileFeature);

    httpContext.RemoveResponseCacheFeature();
}
|
|
|
|
/// <summary>
/// Evaluates the request's conditional headers against the cached response headers to decide
/// whether a 304 Not Modified can be returned instead of the cached body.
/// </summary>
/// <param name="context">Per-request caching state providing the typed request headers and cached response headers.</param>
/// <returns>
/// <c>true</c> when If-None-Match is <c>*</c> or weakly matches the cached ETag, or, when no
/// If-None-Match is present, when the cached Last-Modified (falling back to Date) is not later
/// than If-Modified-Since; otherwise <c>false</c>.
/// </returns>
internal static bool ContentIsNotModified(ResponseCacheContext context)
{
    var cachedResponseHeaders = context.CachedResponseHeaders;
    var ifNoneMatchHeader = context.TypedRequestHeaders.IfNoneMatch;

    // Treat an empty list the same as a missing header so the date-based check still runs.
    var hasIfNoneMatch = ifNoneMatchHeader != null && ifNoneMatchHeader.Count > 0;

    if (hasIfNoneMatch)
    {
        // "If-None-Match: *" matches any current representation.
        if (ifNoneMatchHeader.Count == 1 && ifNoneMatchHeader[0].Equals(EntityTagHeaderValue.Any))
        {
            context.Logger.LogNotModifiedIfNoneMatchStar();
            return true;
        }

        if (cachedResponseHeaders.ETag != null)
        {
            foreach (var tag in ifNoneMatchHeader)
            {
                // If-None-Match uses weak comparison.
                if (cachedResponseHeaders.ETag.Compare(tag, useStrongComparison: false))
                {
                    context.Logger.LogNotModifiedIfNoneMatchMatched(tag);
                    return true;
                }
            }
        }
    }
    else
    {
        // 304 is governed by If-Modified-Since; If-Unmodified-Since is a precondition whose failure
        // is reported as 412 and must not turn a successful GET into a 304. The date header is only
        // consulted when If-None-Match is absent.
        var ifModifiedSince = context.TypedRequestHeaders.IfModifiedSince;
        if (ifModifiedSince != null)
        {
            var lastModified = cachedResponseHeaders.LastModified ?? cachedResponseHeaders.Date;

            // The lifted comparison is false when lastModified is null, so .Value below is safe.
            if (lastModified <= ifModifiedSince)
            {
                // NOTE(review): this logging helper is named for If-Unmodified-Since; rename it (and
                // its message) alongside this fix so the log matches the header actually evaluated.
                context.Logger.LogNotModifiedIfUnmodifiedSinceSatisfied(lastModified.Value, ifModifiedSince.Value);
                return true;
            }
        }
    }

    return false;
}
|
|
|
|
/// <summary>
/// Returns a copy of <paramref name="stringValues"/> with every value upper-cased (invariant culture)
/// and, when there is more than one value, sorted with ordinal comparison.
/// </summary>
/// <param name="stringValues">The values to normalize.</param>
/// <returns>The normalized values.</returns>
internal static StringValues GetOrderCasingNormalizedStringValues(StringValues stringValues)
{
    // A single value has no order to normalize; only the casing changes.
    if (stringValues.Count == 1)
    {
        return new StringValues(stringValues.ToString().ToUpperInvariant());
    }

    var source = stringValues.ToArray();
    var upperCased = new string[source.Length];

    var index = 0;
    foreach (var value in source)
    {
        upperCased[index++] = value.ToUpperInvariant();
    }

    // Casing is already uniform, so an ordinal sort gives a stable canonical order.
    Array.Sort(upperCased, StringComparer.Ordinal);

    return new StringValues(upperCased);
}
|
|
}
|
|
}
|