diff --git a/samples/ResponseCachingSample/Startup.cs b/samples/ResponseCachingSample/Startup.cs index 1cad8d6c1f..66708440a9 100644 --- a/samples/ResponseCachingSample/Startup.cs +++ b/samples/ResponseCachingSample/Startup.cs @@ -15,7 +15,7 @@ namespace ResponseCachingSample { public void ConfigureServices(IServiceCollection services) { - services.AddDistributedResponseCache(); + services.AddDistributedResponseCacheStore(); } public void Configure(IApplicationBuilder app) diff --git a/src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CachedResponse.cs b/src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CachedResponse.cs index 7e32ec2959..1accf9a759 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CachedResponse.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CachedResponse.cs @@ -2,20 +2,19 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO; using Microsoft.AspNetCore.Http; namespace Microsoft.AspNetCore.ResponseCaching { - public class CachedResponse + public class CachedResponse : IResponseCacheEntry { - public string BodyKeyPrefix { get; set; } - public DateTimeOffset Created { get; set; } public int StatusCode { get; set; } public IHeaderDictionary Headers { get; set; } = new HeaderDictionary(); - public byte[] Body { get; set; } + public Stream Body { get; set; } } } diff --git a/src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CachedVaryByRules.cs b/src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CachedVaryByRules.cs index c79d72ddc0..e0e82bd60a 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CachedVaryByRules.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CachedVaryByRules.cs @@ -5,7 +5,7 @@ using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.ResponseCaching { - public class CachedVaryByRules + public class CachedVaryByRules : IResponseCacheEntry { public string VaryByKeyPrefix { get; set; } diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Extensions/ResponseCacheServiceCollectionExtensions.cs b/src/Microsoft.AspNetCore.ResponseCaching/Extensions/ResponseCacheServiceCollectionExtensions.cs index 51af53847e..e57dc5856f 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/Extensions/ResponseCacheServiceCollectionExtensions.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/Extensions/ResponseCacheServiceCollectionExtensions.cs @@ -10,7 +10,7 @@ namespace Microsoft.Extensions.DependencyInjection { public static class ResponseCacheServiceCollectionExtensions { - public static IServiceCollection AddMemoryResponseCache(this IServiceCollection services) + public static IServiceCollection AddMemoryResponseCacheStore(this IServiceCollection services) { if (services == null) { @@ -24,7 +24,7 @@ namespace Microsoft.Extensions.DependencyInjection return services; } - public static IServiceCollection AddDistributedResponseCache(this IServiceCollection services) + public static IServiceCollection AddDistributedResponseCacheStore(this IServiceCollection services) { if (services == null) { diff --git a/src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CachedResponseBody.cs b/src/Microsoft.AspNetCore.ResponseCaching/Interfaces/IResponseCacheEntry.cs similarity index 75% rename from src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CachedResponseBody.cs rename to src/Microsoft.AspNetCore.ResponseCaching/Interfaces/IResponseCacheEntry.cs index a5ce8d6aca..d09fd4da48 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CachedResponseBody.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/Interfaces/IResponseCacheEntry.cs @@ -3,8 +3,7 @@ namespace Microsoft.AspNetCore.ResponseCaching { - public class CachedResponseBody + public interface IResponseCacheEntry { - public byte[] Body { get; set; } } } diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Interfaces/IResponseCacheStore.cs b/src/Microsoft.AspNetCore.ResponseCaching/Interfaces/IResponseCacheStore.cs index 55ed237786..9a12b9b6df 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/Interfaces/IResponseCacheStore.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/Interfaces/IResponseCacheStore.cs @@ -8,8 +8,7 @@ namespace Microsoft.AspNetCore.ResponseCaching { public interface IResponseCacheStore { - Task GetAsync(string key); - Task SetAsync(string key, object entry, TimeSpan validFor); - Task RemoveAsync(string key); + Task GetAsync(string key); + Task SetAsync(string key, IResponseCacheEntry entry, TimeSpan validFor); } } diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Internal/DistributedResponseCacheStore.cs b/src/Microsoft.AspNetCore.ResponseCaching/Internal/DistributedResponseCacheStore.cs index 170f7e65cd..ab708fc81b 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/Internal/DistributedResponseCacheStore.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/Internal/DistributedResponseCacheStore.cs @@ -21,11 +21,11 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal _cache = cache; } - public async Task GetAsync(string key) + public async Task GetAsync(string key) { try { - return CacheEntrySerializer.Deserialize(await _cache.GetAsync(key)); + return ResponseCacheEntrySerializer.Deserialize(await _cache.GetAsync(key)); } catch { @@ -33,22 +33,13 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal } } - public async Task RemoveAsync(string key) - { - try - { - await _cache.RemoveAsync(key); - } - catch { } - } - - public async Task SetAsync(string key, object entry, TimeSpan validFor) + public async Task SetAsync(string key, IResponseCacheEntry entry, TimeSpan validFor) { try { await _cache.SetAsync( key, - CacheEntrySerializer.Serialize(entry), + ResponseCacheEntrySerializer.Serialize(entry), new DistributedCacheEntryOptions() { AbsoluteExpirationRelativeToNow = validFor diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Internal/MemoryCachedResponse.cs b/src/Microsoft.AspNetCore.ResponseCaching/Internal/MemoryCachedResponse.cs new file mode 100644 index 0000000000..d24e63a9ff --- /dev/null +++ b/src/Microsoft.AspNetCore.ResponseCaching/Internal/MemoryCachedResponse.cs @@ -0,0 +1,22 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.ResponseCaching.Internal +{ + internal class MemoryCachedResponse + { + public DateTimeOffset Created { get; set; } + + public int StatusCode { get; set; } + + public IHeaderDictionary Headers { get; set; } = new HeaderDictionary(); + + public List BodySegments { get; set; } + + public long BodyLength { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Internal/MemoryResponseCacheStore.cs b/src/Microsoft.AspNetCore.ResponseCaching/Internal/MemoryResponseCacheStore.cs index 33b7f4367b..1a49f1917e 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/Internal/MemoryResponseCacheStore.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/Internal/MemoryResponseCacheStore.cs @@ -4,7 +4,6 @@ using System; using System.Threading.Tasks; using Microsoft.Extensions.Caching.Memory; -using Microsoft.Extensions.Internal; namespace Microsoft.AspNetCore.ResponseCaching.Internal { @@ -22,27 +21,60 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal _cache = cache; } - public Task GetAsync(string key) + public Task GetAsync(string key) { - return Task.FromResult(_cache.Get(key)); - } + var entry = _cache.Get(key); - public Task RemoveAsync(string key) - { - _cache.Remove(key); - return TaskCache.CompletedTask; - } - - public Task SetAsync(string key, object entry, TimeSpan validFor) - { - _cache.Set( - key, - entry, - new MemoryCacheEntryOptions() + if (entry is MemoryCachedResponse) + { + var memoryCachedResponse = (MemoryCachedResponse)entry; + return Task.FromResult(new CachedResponse() { - AbsoluteExpirationRelativeToNow = validFor + Created = memoryCachedResponse.Created, + StatusCode = memoryCachedResponse.StatusCode, + Headers = memoryCachedResponse.Headers, + Body = new SegmentReadStream(memoryCachedResponse.BodySegments, memoryCachedResponse.BodyLength) }); - return TaskCache.CompletedTask; + } + else + { + return Task.FromResult(entry as IResponseCacheEntry); + } + } + + public async Task SetAsync(string key, IResponseCacheEntry entry, TimeSpan validFor) + { + if (entry is CachedResponse) + { + var cachedResponse = (CachedResponse)entry; + var segmentStream = new SegmentWriteStream(StreamUtilities.BodySegmentSize); + await cachedResponse.Body.CopyToAsync(segmentStream); + + _cache.Set( + key, + new MemoryCachedResponse() + { + Created = cachedResponse.Created, + StatusCode = cachedResponse.StatusCode, + Headers = cachedResponse.Headers, + BodySegments = segmentStream.GetSegments(), + BodyLength = segmentStream.Length + }, + new MemoryCacheEntryOptions() + { + AbsoluteExpirationRelativeToNow = validFor + }); + } + else + { + _cache.Set( + key, + entry, + new MemoryCacheEntryOptions() + { + AbsoluteExpirationRelativeToNow = validFor + }); + } } } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CacheEntrySerializer.cs b/src/Microsoft.AspNetCore.ResponseCaching/Internal/ResponseCacheEntrySerializer.cs similarity index 72% rename from src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CacheEntrySerializer.cs rename to src/Microsoft.AspNetCore.ResponseCaching/Internal/ResponseCacheEntrySerializer.cs index 21534dd13f..46c679d241 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/CacheEntry/CacheEntrySerializer.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/Internal/ResponseCacheEntrySerializer.cs @@ -7,11 +7,11 @@ using Microsoft.AspNetCore.Http; namespace Microsoft.AspNetCore.ResponseCaching.Internal { - internal static class CacheEntrySerializer + internal static class ResponseCacheEntrySerializer { private const int FormatVersion = 1; - public static object Deserialize(byte[] serializedEntry) + internal static IResponseCacheEntry Deserialize(byte[] serializedEntry) { if (serializedEntry == null) { @@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal } } - public static byte[] Serialize(object entry) + internal static byte[] Serialize(IResponseCacheEntry entry) { using (var memory = new MemoryStream()) { @@ -42,9 +42,9 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal // Serialization Format // Format version (int) - // Type (char: 'B' for CachedResponseBody, 'R' for CachedResponse, 'V' for CachedVaryByRules) + // Type (char: 'R' for CachedResponse, 'V' for CachedVaryByRules) // Type-dependent data (see CachedResponse and CachedVaryByRules) - public static object Read(BinaryReader reader) + private static IResponseCacheEntry Read(BinaryReader reader) { if (reader == null) { @@ -58,11 +58,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal var type = reader.ReadChar(); - if (type == 'B') - { - return ReadCachedResponseBody(reader); - } - else if (type == 'R') + if (type == 'R') { return ReadCachedResponse(reader); } @@ -76,18 +72,6 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal } // Serialization Format - // Body length (int) - // Body (byte[]) - private static CachedResponseBody ReadCachedResponseBody(BinaryReader reader) - { - var bodyLength = reader.ReadInt32(); - var body = reader.ReadBytes(bodyLength); - - return new CachedResponseBody() { Body = body }; - } - - // Serialization Format - // BodyKeyPrefix (string) // Creation time - UtcTicks (long) // Status code (int) // Header count (int) @@ -96,12 +80,10 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal // ValueCount (int) // Value(s) // Value (string) - // ContainsBody (bool) - // Body length (int) - // Body (byte[]) + // BodyLength (int) + // Body (byte[]) private static CachedResponse ReadCachedResponse(BinaryReader reader) { - var bodyKeyPrefix = reader.ReadString(); var created = new DateTimeOffset(reader.ReadInt64(), TimeSpan.Zero); var statusCode = reader.ReadInt32(); var headerCount = reader.ReadInt32(); @@ -125,20 +107,20 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal } } - var containsBody = reader.ReadBoolean(); - int bodyLength; - byte[] body = null; - if (containsBody) - { - bodyLength = reader.ReadInt32(); - body = reader.ReadBytes(bodyLength); - } + var bodyLength = reader.ReadInt32(); + var bodyBytes = reader.ReadBytes(bodyLength); - return new CachedResponse { BodyKeyPrefix = bodyKeyPrefix, Created = created, StatusCode = statusCode, Headers = headers, Body = body }; + return new CachedResponse + { + Created = created, + StatusCode = statusCode, + Headers = headers, + Body = new MemoryStream(bodyBytes, writable: false) + }; } // Serialization Format - // Guid (long) + // VaryKeyPrefix (string) // Headers count // Header(s) (comma separated string) // QueryKey count @@ -164,7 +146,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal } // See serialization format above - public static void Write(BinaryWriter writer, object entry) + private static void Write(BinaryWriter writer, IResponseCacheEntry entry) { if (writer == null) { @@ -178,38 +160,25 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal writer.Write(FormatVersion); - if (entry is CachedResponseBody) - { - writer.Write('B'); - WriteCachedResponseBody(writer, entry as CachedResponseBody); - } - else if (entry is CachedResponse) + if (entry is CachedResponse) { writer.Write('R'); - WriteCachedResponse(writer, entry as CachedResponse); + WriteCachedResponse(writer, (CachedResponse)entry); } else if (entry is CachedVaryByRules) { writer.Write('V'); - WriteCachedVaryByRules(writer, entry as CachedVaryByRules); + WriteCachedVaryByRules(writer, (CachedVaryByRules)entry); } else { - throw new NotSupportedException($"Unrecognized entry format for {nameof(entry)}."); + throw new NotSupportedException($"Unrecognized entry type for {nameof(entry)}."); } } - // See serialization format above - private static void WriteCachedResponseBody(BinaryWriter writer, CachedResponseBody entry) - { - writer.Write(entry.Body.Length); - writer.Write(entry.Body); - } - // See serialization format above private static void WriteCachedResponse(BinaryWriter writer, CachedResponse entry) { - writer.Write(entry.BodyKeyPrefix); writer.Write(entry.Created.UtcTicks); writer.Write(entry.StatusCode); writer.Write(entry.Headers.Count); @@ -223,15 +192,39 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal } } - if (entry.Body == null) + if (entry.Body.CanSeek) { - writer.Write(false); + if (entry.Body.Length > int.MaxValue) + { + throw new NotSupportedException($"{nameof(entry.Body)} is too large to serialized."); + } + + var bodyLength = (int)entry.Body.Length; + var bodyBytes = new byte[bodyLength]; + var bytesRead = entry.Body.Read(bodyBytes, 0, bodyLength); + + if (bytesRead != bodyLength) + { + throw new InvalidOperationException($"Failed to fully read {nameof(entry.Body)}."); + } + + writer.Write(bodyLength); + writer.Write(bodyBytes); } else { - writer.Write(true); - writer.Write(entry.Body.Length); - writer.Write(entry.Body); + var stream = new MemoryStream(); + entry.Body.CopyTo(stream); + + if (stream.Length > int.MaxValue) + { + throw new NotSupportedException($"{nameof(entry.Body)} is too large to serialized."); + } + + var bodyLength = (int)stream.Length; + writer.Write(bodyLength); + writer.Write(stream.ToArray()); + } } diff --git a/src/Microsoft.AspNetCore.ResponseCaching/ResponseCacheMiddleware.cs b/src/Microsoft.AspNetCore.ResponseCaching/ResponseCacheMiddleware.cs index 6ceb6ad8d0..72f6902751 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/ResponseCacheMiddleware.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/ResponseCacheMiddleware.cs @@ -130,16 +130,8 @@ namespace Microsoft.AspNetCore.ResponseCaching response.Headers[HeaderNames.Age] = context.CachedEntryAge.Value.TotalSeconds.ToString("F0", CultureInfo.InvariantCulture); - var body = context.CachedResponse.Body ?? - ((CachedResponseBody) await _store.GetAsync(context.CachedResponse.BodyKeyPrefix))?.Body; - - // If the body is not found, something went wrong. - if (body == null) - { - return false; - } - // Copy the cached response body + var body = context.CachedResponse.Body; if (body.Length > 0) { // Add a content-length if required @@ -147,7 +139,15 @@ namespace Microsoft.AspNetCore.ResponseCaching { response.ContentLength = body.Length; } - await response.Body.WriteAsync(body, 0, body.Length); + + try + { + await body.CopyToAsync(response.Body, StreamUtilities.BodySegmentSize, context.HttpContext.RequestAborted); + } + catch (OperationCanceledException) + { + context.HttpContext.Abort(); + } } } @@ -244,7 +244,6 @@ namespace Microsoft.AspNetCore.ResponseCaching // Store the response on the state context.CachedResponse = new CachedResponse { - BodyKeyPrefix = FastGuid.NewGuid().IdString, Created = context.ResponseDate.Value, StatusCode = context.HttpContext.Response.StatusCode }; @@ -266,26 +265,12 @@ namespace Microsoft.AspNetCore.ResponseCaching internal async Task FinalizeCacheBodyAsync(ResponseCacheContext context) { var contentLength = context.TypedResponseHeaders.ContentLength; - if (context.ShouldCacheResponse && - context.ResponseCacheStream.BufferingEnabled && - (!contentLength.HasValue || contentLength == context.ResponseCacheStream.BufferedStream.Length)) + if (context.ShouldCacheResponse && context.ResponseCacheStream.BufferingEnabled) { - if (context.ResponseCacheStream.BufferedStream.Length >= _options.MinimumSplitBodySize) + var bufferStream = context.ResponseCacheStream.GetBufferStream(); + if (!contentLength.HasValue || contentLength == bufferStream.Length) { - // Store response and response body separately - await _store.SetAsync(context.StorageVaryKey ?? context.BaseKey, context.CachedResponse, context.CachedResponseValidFor); - - var cachedResponseBody = new CachedResponseBody() - { - Body = context.ResponseCacheStream.BufferedStream.ToArray() - }; - - await _store.SetAsync(context.CachedResponse.BodyKeyPrefix, cachedResponseBody, context.CachedResponseValidFor); - } - else - { - // Store response and response body together - context.CachedResponse.Body = context.ResponseCacheStream.BufferedStream.ToArray(); + context.CachedResponse.Body = bufferStream; await _store.SetAsync(context.StorageVaryKey ?? context.BaseKey, context.CachedResponse, context.CachedResponseValidFor); } } @@ -310,7 +295,7 @@ namespace Microsoft.AspNetCore.ResponseCaching { // Shim response stream context.OriginalResponseStream = context.HttpContext.Response.Body; - context.ResponseCacheStream = new ResponseCacheStream(context.OriginalResponseStream, _options.MaximumCachedBodySize); + context.ResponseCacheStream = new ResponseCacheStream(context.OriginalResponseStream, _options.MaximumBodySize, StreamUtilities.BodySegmentSize); context.HttpContext.Response.Body = context.ResponseCacheStream; // Shim IHttpSendFileFeature @@ -381,7 +366,7 @@ namespace Microsoft.AspNetCore.ResponseCaching var originalArray = stringValues.ToArray(); var newArray = new string[originalArray.Length]; - for (int i = 0; i < originalArray.Length; i++) + for (var i = 0; i < originalArray.Length; i++) { newArray[i] = originalArray[i].ToUpperInvariant(); } diff --git a/src/Microsoft.AspNetCore.ResponseCaching/ResponseCacheOptions.cs b/src/Microsoft.AspNetCore.ResponseCaching/ResponseCacheOptions.cs index db9bae7575..795c80abcf 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/ResponseCacheOptions.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/ResponseCacheOptions.cs @@ -11,18 +11,13 @@ namespace Microsoft.AspNetCore.Builder /// /// The largest cacheable size for the response body in bytes. The default is set to 1 MB. /// - public long MaximumCachedBodySize { get; set; } = 1024 * 1024; + public long MaximumBodySize { get; set; } = 1024 * 1024; /// /// true if request paths are case-sensitive; otherwise false. The default is to treat paths as case-insensitive. /// public bool UseCaseSensitivePaths { get; set; } = false; - /// - /// The smallest size in bytes for which the headers and body of the response will be stored separately. The default is set to 70 KB. - /// - public long MinimumSplitBodySize { get; set; } = 70 * 1024; - /// /// For testing purposes only. /// diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Internal/ResponseCacheStream.cs b/src/Microsoft.AspNetCore.ResponseCaching/Streams/ResponseCacheStream.cs similarity index 70% rename from src/Microsoft.AspNetCore.ResponseCaching/Internal/ResponseCacheStream.cs rename to src/Microsoft.AspNetCore.ResponseCaching/Streams/ResponseCacheStream.cs index b8921b85ba..40fe217aec 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/Internal/ResponseCacheStream.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/Streams/ResponseCacheStream.cs @@ -12,16 +12,18 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal { private readonly Stream _innerStream; private readonly long _maxBufferSize; + private readonly int _segmentSize; + private SegmentWriteStream _segmentWriteStream; - public ResponseCacheStream(Stream innerStream, long maxBufferSize) + internal ResponseCacheStream(Stream innerStream, long maxBufferSize, int segmentSize) { _innerStream = innerStream; _maxBufferSize = maxBufferSize; + _segmentSize = segmentSize; + _segmentWriteStream = new SegmentWriteStream(_segmentSize); } - public MemoryStream BufferedStream { get; } = new MemoryStream(); - - public bool BufferingEnabled { get; set; } = true; + internal bool BufferingEnabled { get; private set; } = true; public override bool CanRead => _innerStream.CanRead; @@ -34,15 +36,26 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal public override long Position { get { return _innerStream.Position; } - set { _innerStream.Position = value; } + set + { + DisableBuffering(); + _innerStream.Position = value; + } } - public void DisableBuffering() + internal Stream GetBufferStream() + { + if (!BufferingEnabled) + { + throw new InvalidOperationException("Buffer stream cannot be retrieved since buffering is disabled."); + } + return new SegmentReadStream(_segmentWriteStream.GetSegments(), _segmentWriteStream.Length); + } + + internal void DisableBuffering() { BufferingEnabled = false; - BufferedStream.SetLength(0); - BufferedStream.Capacity = 0; - BufferedStream.Dispose(); + _segmentWriteStream.Dispose(); } public override void SetLength(long value) @@ -81,13 +94,13 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal if (BufferingEnabled) { - if (BufferedStream.Length + count > _maxBufferSize) + if (_segmentWriteStream.Length + count > _maxBufferSize) { DisableBuffering(); } else { - BufferedStream.Write(buffer, offset, count); + _segmentWriteStream.Write(buffer, offset, count); } } } @@ -106,13 +119,13 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal if (BufferingEnabled) { - if (BufferedStream.Length + count > _maxBufferSize) + if (_segmentWriteStream.Length + count > _maxBufferSize) { DisableBuffering(); } else { - await BufferedStream.WriteAsync(buffer, offset, count, cancellationToken); + await _segmentWriteStream.WriteAsync(buffer, offset, count, cancellationToken); } } } @@ -131,13 +144,13 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal if (BufferingEnabled) { - if (BufferedStream.Length + 1 > _maxBufferSize) + if (_segmentWriteStream.Length + 1 > _maxBufferSize) { DisableBuffering(); } else { - BufferedStream.WriteByte(value); + _segmentWriteStream.WriteByte(value); } } } @@ -148,7 +161,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) #endif { - return ToIAsyncResult(WriteAsync(buffer, offset, count), callback, state); + return StreamUtilities.ToIAsyncResult(WriteAsync(buffer, offset, count), callback, state); } #if NETSTANDARD1_3 public void EndWrite(IAsyncResult asyncResult) @@ -162,28 +175,5 @@ namespace Microsoft.AspNetCore.ResponseCaching.Internal } ((Task)asyncResult).GetAwaiter().GetResult(); } - - private static IAsyncResult ToIAsyncResult(Task task, AsyncCallback callback, object state) - { - var tcs = new TaskCompletionSource(state); - task.ContinueWith(t => - { - if (t.IsFaulted) - { - tcs.TrySetException(t.Exception.InnerExceptions); - } - else if (t.IsCanceled) - { - tcs.TrySetCanceled(); - } - else - { - tcs.TrySetResult(0); - } - - callback?.Invoke(tcs.Task); - }, CancellationToken.None, TaskContinuationOptions.None, TaskScheduler.Default); - return tcs.Task; - } } } diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Streams/SegmentReadStream.cs b/src/Microsoft.AspNetCore.ResponseCaching/Streams/SegmentReadStream.cs new file mode 100644 index 0000000000..51ff15ff41 --- /dev/null +++ b/src/Microsoft.AspNetCore.ResponseCaching/Streams/SegmentReadStream.cs @@ -0,0 +1,238 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.ResponseCaching.Internal +{ + internal class SegmentReadStream : Stream + { + private readonly List _segments; + private readonly long _length; + private int _segmentIndex; + private int _segmentOffset; + private long _position; + + internal SegmentReadStream(List segments, long length) + { + if (segments == null) + { + throw new ArgumentNullException(nameof(segments)); + } + + _segments = segments; + _length = length; + } + + public override bool CanRead => true; + + public override bool CanSeek => true; + + public override bool CanWrite => false; + + public override long Length => _length; + + public override long Position + { + get + { + return _position; + } + set + { + // The stream only supports a full rewind. This will need an update if random access becomes a required feature. + if (value != 0) + { + throw new ArgumentOutOfRangeException(nameof(value), value, $"{nameof(Position)} can only be set to 0."); + } + + _position = 0; + _segmentOffset = 0; + _segmentIndex = 0; + } + } + + public override void Flush() + { + throw new NotSupportedException("The stream does not support writing."); + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } + if (offset < 0) + { + throw new ArgumentOutOfRangeException(nameof(offset), offset, "Non-negative number required."); + } + // Read of length 0 will return zero and indicate end of stream. + if (count <= 0 ) + { + throw new ArgumentOutOfRangeException(nameof(count), count, "Positive number required."); + } + if (count > buffer.Length - offset) + { + throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); + } + + if (_segmentIndex == _segments.Count) + { + return 0; + } + + var bytesRead = 0; + while (count > 0) + { + if (_segmentOffset == _segments[_segmentIndex].Length) + { + // Move to the next segment + _segmentIndex++; + _segmentOffset = 0; + + if (_segmentIndex == _segments.Count) + { + break; + } + } + + // Read up to the end of the segment + var segmentBytesRead = Math.Min(count, _segments[_segmentIndex].Length - _segmentOffset); + Buffer.BlockCopy(_segments[_segmentIndex], _segmentOffset, buffer, offset, segmentBytesRead); + bytesRead += segmentBytesRead; + _segmentOffset += segmentBytesRead; + _position += segmentBytesRead; + offset += segmentBytesRead; + count -= segmentBytesRead; + } + + return bytesRead; + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return Task.FromResult(Read(buffer, offset, count)); + } + + public override int ReadByte() + { + if (Position == Length) + { + return -1; + } + + if (_segmentOffset == _segments[_segmentIndex].Length) + { + // Move to the next segment + _segmentIndex++; + _segmentOffset = 0; + } + + var byteRead = _segments[_segmentIndex][_segmentOffset]; + _segmentOffset++; + _position++; + + return byteRead; + } + +#if NETSTANDARD1_3 + public IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) +#else + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) +#endif + { + var tcs = new TaskCompletionSource(state); + + try + { + tcs.TrySetResult(Read(buffer, offset, count)); + } + catch (Exception ex) + { + tcs.TrySetException(ex); + } + + if (callback != null) + { + // Offload callbacks to avoid stack dives on sync completions. + var ignored = Task.Run(() => + { + try + { + callback(tcs.Task); + } + catch (Exception) + { + // Suppress exceptions on background threads. + } + }); + } + + return tcs.Task; + } + +#if NETSTANDARD1_3 + public int EndRead(IAsyncResult asyncResult) +#else + public override int EndRead(IAsyncResult asyncResult) +#endif + { + if (asyncResult == null) + { + throw new ArgumentNullException(nameof(asyncResult)); + } + return ((Task)asyncResult).GetAwaiter().GetResult(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + // The stream only supports a full rewind. This will need an update if random access becomes a required feature. + if (origin != SeekOrigin.Begin) + { + throw new ArgumentException(nameof(origin), $"{nameof(Seek)} can only be set to {nameof(SeekOrigin.Begin)}."); + } + if (offset != 0) + { + throw new ArgumentOutOfRangeException(nameof(offset), offset, $"{nameof(Seek)} can only be set to 0."); + } + + Position = 0; + return Position; + } + + public override void SetLength(long value) + { + throw new NotSupportedException("The stream does not support writing."); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException("The stream does not support writing."); + } + + public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + if (destination == null) + { + throw new ArgumentNullException(nameof(destination)); + } + if (!destination.CanWrite) + { + throw new NotSupportedException("The destination stream does not support writing."); + } + + for (; _segmentIndex < _segments.Count; _segmentIndex++, _segmentOffset = 0) + { + cancellationToken.ThrowIfCancellationRequested(); + var bytesCopied = _segments[_segmentIndex].Length - _segmentOffset; + await destination.WriteAsync(_segments[_segmentIndex], _segmentOffset, bytesCopied, cancellationToken); + _position += bytesCopied; + } + } + } +} diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Streams/SegmentWriteStream.cs b/src/Microsoft.AspNetCore.ResponseCaching/Streams/SegmentWriteStream.cs new file mode 100644 index 0000000000..6fb93c5c7b --- /dev/null +++ b/src/Microsoft.AspNetCore.ResponseCaching/Streams/SegmentWriteStream.cs @@ -0,0 +1,215 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Internal; + +namespace Microsoft.AspNetCore.ResponseCaching.Internal +{ + internal class SegmentWriteStream : Stream + { + private readonly List _segments = new List(); + private readonly MemoryStream _bufferStream = new MemoryStream(); + private readonly int _segmentSize; + private long _length; + private bool _closed; + private bool _disposed; + + internal SegmentWriteStream(int segmentSize) + { + if (segmentSize <= 0) + { + throw new ArgumentOutOfRangeException(nameof(segmentSize), segmentSize, $"{nameof(segmentSize)} must be greater than 0."); + } + + _segmentSize = segmentSize; + } + + // Extracting the buffered segments closes the stream for writing + internal List GetSegments() + { + if (!_closed) + { + _closed = true; + FinalizeSegments(); + } + return _segments; + } + + public override bool CanRead => false; + + public override bool CanSeek => false; + + public override bool CanWrite => !_closed; + + public override long Length => _length; + + public override long Position + { + get + { + return _length; + } + set + { + throw new NotSupportedException("The stream does not support seeking."); + } + } + + private void DisposeMemoryStream() + { + // Clean up the memory stream + _bufferStream.SetLength(0); + _bufferStream.Capacity = 0; + _bufferStream.Dispose(); + } + + private void FinalizeSegments() + { + // Append any remaining segments + if (_bufferStream.Length > 0) + { + // Add the last segment + _segments.Add(_bufferStream.ToArray()); + } + + DisposeMemoryStream(); + } + + protected override void Dispose(bool disposing) + { + try + { + if (_disposed) + { + return; + } + + if (disposing) + { + _segments.Clear(); + DisposeMemoryStream(); + } + + _disposed = true; + _closed = true; + } + finally + { + base.Dispose(disposing); + } + } + + public override void Flush() + { + if (!CanWrite) + { + throw new ObjectDisposedException("The stream has been closed for writing."); + } + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotSupportedException("The stream does not support reading."); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException("The stream does not support seeking."); + } + + public override void SetLength(long value) + { + throw new NotSupportedException("The stream does not support seeking."); + } + + public override void Write(byte[] buffer, int offset, int count) + { + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } + if (offset < 0) + { + throw new ArgumentOutOfRangeException(nameof(offset), offset, "Non-negative number required."); + } + if (count < 0) + { + throw new ArgumentOutOfRangeException(nameof(count), count, "Non-negative number required."); + } + if (count > buffer.Length - offset) + { + throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); + } + if (!CanWrite) + { + throw new ObjectDisposedException("The stream has been closed for writing."); + } + + while (count > 0) + { + if ((int)_bufferStream.Length == _segmentSize) + { + _segments.Add(_bufferStream.ToArray()); + _bufferStream.SetLength(0); + } + + var bytesWritten = Math.Min(count, _segmentSize - (int)_bufferStream.Length); + + _bufferStream.Write(buffer, offset, bytesWritten); + count -= bytesWritten; + offset += bytesWritten; + _length += bytesWritten; + } + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + Write(buffer, offset, count); + return TaskCache.CompletedTask; + } + + public override void WriteByte(byte value) + { + if (!CanWrite) + { + throw new ObjectDisposedException("The stream has been closed for writing."); + } + + if ((int)_bufferStream.Length == _segmentSize) + { + _segments.Add(_bufferStream.ToArray()); + _bufferStream.SetLength(0); + } + + _bufferStream.WriteByte(value); + _length++; + } + +#if NETSTANDARD1_3 + public IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) +#else + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) +#endif + { + return StreamUtilities.ToIAsyncResult(WriteAsync(buffer, offset, count), callback, state); + } + +#if NETSTANDARD1_3 + public void EndWrite(IAsyncResult asyncResult) +#else + public override void EndWrite(IAsyncResult asyncResult) +#endif + { + if (asyncResult == null) + { + throw new ArgumentNullException(nameof(asyncResult)); + } + ((Task)asyncResult).GetAwaiter().GetResult(); + } + } +} diff --git a/src/Microsoft.AspNetCore.ResponseCaching/Streams/StreamUtilities.cs b/src/Microsoft.AspNetCore.ResponseCaching/Streams/StreamUtilities.cs new file mode 100644 index 0000000000..4ce5a4ebe0 --- /dev/null +++ b/src/Microsoft.AspNetCore.ResponseCaching/Streams/StreamUtilities.cs @@ -0,0 +1,41 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.ResponseCaching.Internal +{ + internal static class StreamUtilities + { + /// + /// The segment size for buffering the response body in bytes. The default is set to 84 KB. + /// + // Internal for testing + internal static int BodySegmentSize { get; set; } = 84 * 1024; + + internal static IAsyncResult ToIAsyncResult(Task task, AsyncCallback callback, object state) + { + var tcs = new TaskCompletionSource(state); + task.ContinueWith(t => + { + if (t.IsFaulted) + { + tcs.TrySetException(t.Exception.InnerExceptions); + } + else if (t.IsCanceled) + { + tcs.TrySetCanceled(); + } + else + { + tcs.TrySetResult(0); + } + + callback?.Invoke(tcs.Task); + }, CancellationToken.None, TaskContinuationOptions.None, TaskScheduler.Default); + return tcs.Task; + } + } +} diff --git a/test/Microsoft.AspNetCore.ResponseCaching.Tests/CacheEntrySerializerTests.cs b/test/Microsoft.AspNetCore.ResponseCaching.Tests/CacheEntrySerializerTests.cs index 966fdee6e0..7664e437cb 100644 --- a/test/Microsoft.AspNetCore.ResponseCaching.Tests/CacheEntrySerializerTests.cs +++ b/test/Microsoft.AspNetCore.ResponseCaching.Tests/CacheEntrySerializerTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Linq; using System.Text; using Microsoft.AspNetCore.Http; @@ -16,47 +17,23 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public void Serialize_NullObject_Throws() { - Assert.Throws(() => CacheEntrySerializer.Serialize(null)); + Assert.Throws(() => ResponseCacheEntrySerializer.Serialize(null)); + } + + private class UnknownResponseCacheEntry : IResponseCacheEntry + { } [Fact] public void Serialize_UnknownObject_Throws() { - Assert.Throws(() => CacheEntrySerializer.Serialize(new object())); + Assert.Throws(() => ResponseCacheEntrySerializer.Serialize(new UnknownResponseCacheEntry())); } [Fact] public void Deserialize_NullObject_ReturnsNull() { - Assert.Null(CacheEntrySerializer.Deserialize(null)); - } - - [Fact] - public void RoundTrip_CachedResponseBody_Succeeds() - { - var cachedResponseBody = new CachedResponseBody() - { - Body = Encoding.ASCII.GetBytes("Hello world"), - }; - - AssertCachedResponseBodyEqual(cachedResponseBody, (CachedResponseBody)CacheEntrySerializer.Deserialize(CacheEntrySerializer.Serialize(cachedResponseBody))); - } - - [Fact] - public void RoundTrip_CachedResponseWithoutBody_Succeeds() - { - var headers = new HeaderDictionary(); - headers["keyA"] = "valueA"; - headers["keyB"] = "valueB"; - var cachedResponse = new CachedResponse() - { - BodyKeyPrefix = FastGuid.NewGuid().IdString, - Created = DateTimeOffset.UtcNow, - StatusCode = StatusCodes.Status200OK, - Headers = headers - }; - - AssertCachedResponseEqual(cachedResponse, (CachedResponse)CacheEntrySerializer.Deserialize(CacheEntrySerializer.Serialize(cachedResponse))); + Assert.Null(ResponseCacheEntrySerializer.Deserialize(null)); } [Fact] @@ -65,16 +42,16 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests var headers = new HeaderDictionary(); headers["keyA"] = "valueA"; headers["keyB"] = "valueB"; + var body = Encoding.ASCII.GetBytes("Hello world"); var cachedResponse = new CachedResponse() { - BodyKeyPrefix = FastGuid.NewGuid().IdString, Created = DateTimeOffset.UtcNow, StatusCode = StatusCodes.Status200OK, - Body = Encoding.ASCII.GetBytes("Hello world"), + Body = new SegmentReadStream(new List(new[] { body }), body.Length), Headers = headers }; - AssertCachedResponseEqual(cachedResponse, (CachedResponse)CacheEntrySerializer.Deserialize(CacheEntrySerializer.Serialize(cachedResponse))); + AssertCachedResponseEqual(cachedResponse, (CachedResponse)ResponseCacheEntrySerializer.Deserialize(ResponseCacheEntrySerializer.Serialize(cachedResponse))); } [Fact] @@ -82,16 +59,16 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests { var headers = new HeaderDictionary(); headers["keyA"] = new StringValues(new[] { "ValueA", "ValueB" }); + var body = Encoding.ASCII.GetBytes("Hello world"); var cachedResponse = new CachedResponse() { - BodyKeyPrefix = FastGuid.NewGuid().IdString, Created = DateTimeOffset.UtcNow, StatusCode = StatusCodes.Status200OK, - Body = Encoding.ASCII.GetBytes("Hello world"), + Body = new SegmentReadStream(new List(new[] { body }), body.Length), Headers = headers }; - AssertCachedResponseEqual(cachedResponse, (CachedResponse)CacheEntrySerializer.Deserialize(CacheEntrySerializer.Serialize(cachedResponse))); + AssertCachedResponseEqual(cachedResponse, (CachedResponse)ResponseCacheEntrySerializer.Deserialize(ResponseCacheEntrySerializer.Serialize(cachedResponse))); } [Fact] @@ -99,16 +76,16 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests { var headers = new HeaderDictionary(); headers["keyA"] = StringValues.Empty; + var body = Encoding.ASCII.GetBytes("Hello world"); var cachedResponse = new CachedResponse() { - BodyKeyPrefix = FastGuid.NewGuid().IdString, Created = DateTimeOffset.UtcNow, StatusCode = StatusCodes.Status200OK, - Body = Encoding.ASCII.GetBytes("Hello world"), + Body = new SegmentReadStream(new List(new[] { body }), body.Length), Headers = headers }; - AssertCachedResponseEqual(cachedResponse, (CachedResponse)CacheEntrySerializer.Deserialize(CacheEntrySerializer.Serialize(cachedResponse))); + AssertCachedResponseEqual(cachedResponse, (CachedResponse)ResponseCacheEntrySerializer.Deserialize(ResponseCacheEntrySerializer.Serialize(cachedResponse))); } [Fact] @@ -119,7 +96,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests VaryByKeyPrefix = FastGuid.NewGuid().IdString }; - AssertCachedVaryByRuleEqual(cachedVaryByRule, (CachedVaryByRules)CacheEntrySerializer.Deserialize(CacheEntrySerializer.Serialize(cachedVaryByRule))); + AssertCachedVaryByRuleEqual(cachedVaryByRule, (CachedVaryByRules)ResponseCacheEntrySerializer.Deserialize(ResponseCacheEntrySerializer.Serialize(cachedVaryByRule))); } [Fact] @@ -132,7 +109,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests Headers = headers }; - AssertCachedVaryByRuleEqual(cachedVaryByRule, (CachedVaryByRules)CacheEntrySerializer.Deserialize(CacheEntrySerializer.Serialize(cachedVaryByRule))); + AssertCachedVaryByRuleEqual(cachedVaryByRule, (CachedVaryByRules)ResponseCacheEntrySerializer.Deserialize(ResponseCacheEntrySerializer.Serialize(cachedVaryByRule))); } [Fact] @@ -145,7 +122,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests QueryKeys = queryKeys }; - AssertCachedVaryByRuleEqual(cachedVaryByRule, (CachedVaryByRules)CacheEntrySerializer.Deserialize(CacheEntrySerializer.Serialize(cachedVaryByRule))); + AssertCachedVaryByRuleEqual(cachedVaryByRule, (CachedVaryByRules)ResponseCacheEntrySerializer.Deserialize(ResponseCacheEntrySerializer.Serialize(cachedVaryByRule))); } [Fact] @@ -160,7 +137,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests QueryKeys = queryKeys }; - AssertCachedVaryByRuleEqual(cachedVaryByRule, (CachedVaryByRules)CacheEntrySerializer.Deserialize(CacheEntrySerializer.Serialize(cachedVaryByRule))); + AssertCachedVaryByRuleEqual(cachedVaryByRule, (CachedVaryByRules)ResponseCacheEntrySerializer.Deserialize(ResponseCacheEntrySerializer.Serialize(cachedVaryByRule))); } [Fact] @@ -172,22 +149,16 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests VaryByKeyPrefix = FastGuid.NewGuid().IdString, Headers = headers }; - var serializedEntry = CacheEntrySerializer.Serialize(cachedVaryByRule); + var serializedEntry = ResponseCacheEntrySerializer.Serialize(cachedVaryByRule); Array.Reverse(serializedEntry); - Assert.Null(CacheEntrySerializer.Deserialize(serializedEntry)); - } - - private static void AssertCachedResponseBodyEqual(CachedResponseBody expected, CachedResponseBody actual) - { - Assert.True(expected.Body.SequenceEqual(actual.Body)); + Assert.Null(ResponseCacheEntrySerializer.Deserialize(serializedEntry)); } private static void AssertCachedResponseEqual(CachedResponse expected, CachedResponse actual) { Assert.NotNull(actual); Assert.NotNull(expected); - Assert.Equal(expected.BodyKeyPrefix, actual.BodyKeyPrefix); Assert.Equal(expected.Created, actual.Created); Assert.Equal(expected.StatusCode, actual.StatusCode); Assert.Equal(expected.Headers.Count, actual.Headers.Count); @@ -195,14 +166,15 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests { Assert.Equal(expectedHeader.Value, actual.Headers[expectedHeader.Key]); } - if (expected.Body == null) - { - Assert.Null(actual.Body); - } - else - { - Assert.True(expected.Body.SequenceEqual(actual.Body)); - } + + Assert.Equal(expected.Body.Length, actual.Body.Length); + var bodyLength = (int)expected.Body.Length; + var expectedBytes = new byte[bodyLength]; + var actualBytes = new byte[bodyLength]; + expected.Body.Position = 0; // Rewind + Assert.Equal(bodyLength, expected.Body.Read(expectedBytes, 0, bodyLength)); + Assert.Equal(bodyLength, actual.Body.Read(actualBytes, 0, bodyLength)); + Assert.True(expectedBytes.SequenceEqual(actualBytes)); } private static void AssertCachedVaryByRuleEqual(CachedVaryByRules expected, CachedVaryByRules actual) diff --git a/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCacheMiddlewareTests.cs b/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCacheMiddlewareTests.cs index 5d8bebc4a4..792c6630e5 100644 --- a/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCacheMiddlewareTests.cs +++ b/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCacheMiddlewareTests.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; -using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Headers; using Microsoft.AspNetCore.ResponseCaching.Internal; @@ -53,7 +52,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests "BaseKey", new CachedResponse() { - Body = new byte[0] + Body = new SegmentReadStream(new List(0), 0) }, TimeSpan.Zero); @@ -92,7 +91,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests "BaseKeyVaryKey2", new CachedResponse() { - Body = new byte[0] + Body = new SegmentReadStream(new List(0), 0) }, TimeSpan.Zero); @@ -424,78 +423,6 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests Assert.Equal(new StringValues(new[] { "HEADERA", "HEADERB" }), context.CachedVaryByRules.Headers); } - [Fact] - public async Task FinalizeCacheBody_StoreResponseBodySeparately_IfLargerThanLimit() - { - var store = new TestResponseCacheStore(); - var middleware = TestUtils.CreateTestMiddleware(store); - var context = TestUtils.CreateTestContext(); - - middleware.ShimResponseStream(context); - await context.HttpContext.Response.WriteAsync(new string('0', 70 * 1024)); - - context.ShouldCacheResponse = true; - context.CachedResponse = new CachedResponse() - { - BodyKeyPrefix = FastGuid.NewGuid().IdString - }; - context.BaseKey = "BaseKey"; - context.CachedResponseValidFor = TimeSpan.FromSeconds(10); - - await middleware.FinalizeCacheBodyAsync(context); - - Assert.Equal(2, store.SetCount); - } - - [Fact] - public async Task FinalizeCacheBody_StoreResponseBodyInCachedResponse_IfSmallerThanLimit() - { - var store = new TestResponseCacheStore(); - var middleware = TestUtils.CreateTestMiddleware(store); - var context = TestUtils.CreateTestContext(); - - middleware.ShimResponseStream(context); - await context.HttpContext.Response.WriteAsync(new string('0', 70 * 1024 - 1)); - - context.ShouldCacheResponse = true; - context.CachedResponse = new CachedResponse() - { - BodyKeyPrefix = FastGuid.NewGuid().IdString - }; - context.BaseKey = "BaseKey"; - context.CachedResponseValidFor = TimeSpan.FromSeconds(10); - - await middleware.FinalizeCacheBodyAsync(context); - - Assert.Equal(1, store.SetCount); - } - - [Fact] - public async Task FinalizeCacheBody_StoreResponseBodySeparately_LimitIsConfigurable() - { - var store = new TestResponseCacheStore(); - var middleware = TestUtils.CreateTestMiddleware(store, new ResponseCacheOptions() - { - MinimumSplitBodySize = 2048 - }); - var context = TestUtils.CreateTestContext(); - - middleware.ShimResponseStream(context); - await context.HttpContext.Response.WriteAsync(new string('0', 1024)); - - context.ShouldCacheResponse = true; - context.CachedResponse = new CachedResponse() - { - BodyKeyPrefix = FastGuid.NewGuid().IdString - }; - context.BaseKey = "BaseKey"; - context.CachedResponseValidFor = TimeSpan.FromSeconds(10); - - await middleware.FinalizeCacheBodyAsync(context); - - Assert.Equal(1, store.SetCount); - } - [Fact] public async Task FinalizeCacheBody_Cache_IfContentLengthMatches() { @@ -504,14 +431,11 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests var context = TestUtils.CreateTestContext(); middleware.ShimResponseStream(context); - context.HttpContext.Response.ContentLength = 10; - await context.HttpContext.Response.WriteAsync(new string('0', 10)); + context.HttpContext.Response.ContentLength = 20; + await context.HttpContext.Response.WriteAsync(new string('0', 20)); context.ShouldCacheResponse = true; - context.CachedResponse = new CachedResponse() - { - BodyKeyPrefix = FastGuid.NewGuid().IdString - }; + context.CachedResponse = new CachedResponse(); context.BaseKey = "BaseKey"; context.CachedResponseValidFor = TimeSpan.FromSeconds(10); @@ -532,10 +456,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests await context.HttpContext.Response.WriteAsync(new string('0', 10)); context.ShouldCacheResponse = true; - context.CachedResponse = new CachedResponse() - { - BodyKeyPrefix = FastGuid.NewGuid().IdString - }; + context.CachedResponse = new CachedResponse(); context.BaseKey = "BaseKey"; context.CachedResponseValidFor = TimeSpan.FromSeconds(10); @@ -555,10 +476,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests await context.HttpContext.Response.WriteAsync(new string('0', 10)); context.ShouldCacheResponse = true; - context.CachedResponse = new CachedResponse() - { - BodyKeyPrefix = FastGuid.NewGuid().IdString - }; + context.CachedResponse = new CachedResponse(); context.BaseKey = "BaseKey"; context.CachedResponseValidFor = TimeSpan.FromSeconds(10); diff --git a/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCacheTests.cs b/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCacheTests.cs index 09f432d3b7..bc473d12a5 100644 --- a/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCacheTests.cs +++ b/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCacheTests.cs @@ -19,288 +19,333 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests [Fact] public async void ServesCachedContent_IfAvailable() { - var builder = TestUtils.CreateBuilderWithResponseCache(); + var builders = TestUtils.CreateBuildersWithResponseCache(); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + var subsequentResponse = await client.GetAsync(""); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesFreshContent_IfNotAvailable() { - var builder = TestUtils.CreateBuilderWithResponseCache(); + var builders = TestUtils.CreateBuildersWithResponseCache(); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - var subsequentResponse = await client.GetAsync("/different"); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + var subsequentResponse = await client.GetAsync("/different"); - await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesCachedContent_IfVaryHeader_Matches() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.Response.Headers[HeaderNames.Vary] = HeaderNames.From; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - client.DefaultRequestHeaders.From = "user@example.com"; - var initialResponse = await client.GetAsync(""); - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + client.DefaultRequestHeaders.From = "user@example.com"; + var initialResponse = await client.GetAsync(""); + var subsequentResponse = await client.GetAsync(""); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesFreshContent_IfVaryHeader_Mismatches() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.Response.Headers[HeaderNames.Vary] = HeaderNames.From; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - client.DefaultRequestHeaders.From = "user@example.com"; - var initialResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.From = "user2@example.com"; - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + client.DefaultRequestHeaders.From = "user@example.com"; + var initialResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.From = "user2@example.com"; + var subsequentResponse = await client.GetAsync(""); - await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesCachedContent_IfVaryQueryKeys_Matches() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.GetResponseCacheFeature().VaryByQueryKeys = "query"; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync("?query=value"); - var subsequentResponse = await client.GetAsync("?query=value"); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync("?query=value"); + var subsequentResponse = await client.GetAsync("?query=value"); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesCachedContent_IfVaryQueryKeysExplicit_Matches_QueryKeyCaseInsensitive() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.GetResponseCacheFeature().VaryByQueryKeys = new[] { "QueryA", "queryb" }; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync("?querya=valuea&queryb=valueb"); - var subsequentResponse = await client.GetAsync("?QueryA=valuea&QueryB=valueb"); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync("?querya=valuea&queryb=valueb"); + var subsequentResponse = await client.GetAsync("?QueryA=valuea&QueryB=valueb"); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesCachedContent_IfVaryQueryKeyStar_Matches_QueryKeyCaseInsensitive() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.GetResponseCacheFeature().VaryByQueryKeys = new[] { "*" }; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync("?querya=valuea&queryb=valueb"); - var subsequentResponse = await client.GetAsync("?QueryA=valuea&QueryB=valueb"); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync("?querya=valuea&queryb=valueb"); + var subsequentResponse = await client.GetAsync("?QueryA=valuea&QueryB=valueb"); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesCachedContent_IfVaryQueryKeyExplicit_Matches_OrderInsensitive() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.GetResponseCacheFeature().VaryByQueryKeys = new[] { "QueryB", "QueryA" }; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync("?QueryA=ValueA&QueryB=ValueB"); - var subsequentResponse = await client.GetAsync("?QueryB=ValueB&QueryA=ValueA"); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync("?QueryA=ValueA&QueryB=ValueB"); + var subsequentResponse = await client.GetAsync("?QueryB=ValueB&QueryA=ValueA"); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesCachedContent_IfVaryQueryKeyStar_Matches_OrderInsensitive() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.GetResponseCacheFeature().VaryByQueryKeys = new[] { "*" }; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync("?QueryA=ValueA&QueryB=ValueB"); - var subsequentResponse = await client.GetAsync("?QueryB=ValueB&QueryA=ValueA"); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync("?QueryA=ValueA&QueryB=ValueB"); + var subsequentResponse = await client.GetAsync("?QueryB=ValueB&QueryA=ValueA"); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesFreshContent_IfVaryQueryKey_Mismatches() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.GetResponseCacheFeature().VaryByQueryKeys = "query"; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync("?query=value"); - var subsequentResponse = await client.GetAsync("?query=value2"); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync("?query=value"); + var subsequentResponse = await client.GetAsync("?query=value2"); - await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesFreshContent_IfVaryQueryKeyExplicit_Mismatch_QueryKeyCaseSensitive() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.GetResponseCacheFeature().VaryByQueryKeys = new[] { "QueryA", "QueryB" }; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync("?querya=valuea&queryb=valueb"); - var subsequentResponse = await client.GetAsync("?querya=ValueA&queryb=ValueB"); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync("?querya=valuea&queryb=valueb"); + var subsequentResponse = await client.GetAsync("?querya=ValueA&queryb=ValueB"); - await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesFreshContent_IfVaryQueryKeyStar_Mismatch_QueryKeyValueCaseSensitive() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.GetResponseCacheFeature().VaryByQueryKeys = new[] { "*" }; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync("?querya=valuea&queryb=valueb"); - var subsequentResponse = await client.GetAsync("?querya=ValueA&queryb=ValueB"); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync("?querya=valuea&queryb=valueb"); + var subsequentResponse = await client.GetAsync("?querya=ValueA&queryb=ValueB"); - await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesFreshContent_IfRequestRequirements_NotMet() { - var builder = TestUtils.CreateBuilderWithResponseCache(); + var builders = TestUtils.CreateBuildersWithResponseCache(); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.CacheControl = new System.Net.Http.Headers.CacheControlHeaderValue() + using (var server = new TestServer(builder)) { - MaxAge = TimeSpan.FromSeconds(0) - }; - var subsequentResponse = await client.GetAsync(""); + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.CacheControl = new System.Net.Http.Headers.CacheControlHeaderValue() + { + MaxAge = TimeSpan.FromSeconds(0) + }; + var subsequentResponse = await client.GetAsync(""); - await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void Serves504_IfOnlyIfCachedHeader_IsSpecified() { - var builder = TestUtils.CreateBuilderWithResponseCache(); + var builders = TestUtils.CreateBuildersWithResponseCache(); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.CacheControl = new System.Net.Http.Headers.CacheControlHeaderValue() + using (var server = new TestServer(builder)) { - OnlyIfCached = true - }; - var subsequentResponse = await client.GetAsync("/different"); + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.CacheControl = new System.Net.Http.Headers.CacheControlHeaderValue() + { + OnlyIfCached = true + }; + var subsequentResponse = await client.GetAsync("/different"); - initialResponse.EnsureSuccessStatusCode(); - Assert.Equal(System.Net.HttpStatusCode.GatewayTimeout, subsequentResponse.StatusCode); + initialResponse.EnsureSuccessStatusCode(); + Assert.Equal(System.Net.HttpStatusCode.GatewayTimeout, subsequentResponse.StatusCode); + } } } [Fact] public async void ServesFreshContent_IfSetCookie_IsSpecified() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { var headers = context.Response.Headers[HeaderNames.SetCookie] = "cookieName=cookieValue"; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + var subsequentResponse = await client.GetAsync(""); - await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesCachedContent_IfIHttpSendFileFeature_NotUsed() { - var builder = TestUtils.CreateBuilderWithResponseCache(app => + var builders = TestUtils.CreateBuildersWithResponseCache(app => { app.Use(async (context, next) => { @@ -309,20 +354,23 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests }); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + var subsequentResponse = await client.GetAsync(""); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesFreshContent_IfIHttpSendFileFeature_Used() { - var builder = TestUtils.CreateBuilderWithResponseCache( + var builders = TestUtils.CreateBuildersWithResponseCache( app => { app.Use(async (context, next) => @@ -337,248 +385,284 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + var subsequentResponse = await client.GetAsync(""); - await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesCachedContent_IfSubsequentRequest_ContainsNoStore() { - var builder = TestUtils.CreateBuilderWithResponseCache(); + var builders = TestUtils.CreateBuildersWithResponseCache(); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.CacheControl = new System.Net.Http.Headers.CacheControlHeaderValue() + using (var server = new TestServer(builder)) { - NoStore = true - }; - var subsequentResponse = await client.GetAsync(""); + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.CacheControl = new System.Net.Http.Headers.CacheControlHeaderValue() + { + NoStore = true + }; + var subsequentResponse = await client.GetAsync(""); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesFreshContent_IfInitialRequestContains_NoStore() { - var builder = TestUtils.CreateBuilderWithResponseCache(); + var builders = TestUtils.CreateBuildersWithResponseCache(); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - client.DefaultRequestHeaders.CacheControl = new System.Net.Http.Headers.CacheControlHeaderValue() + using (var server = new TestServer(builder)) { - NoStore = true - }; - var initialResponse = await client.GetAsync(""); - var subsequentResponse = await client.GetAsync(""); + var client = server.CreateClient(); + client.DefaultRequestHeaders.CacheControl = new System.Net.Http.Headers.CacheControlHeaderValue() + { + NoStore = true + }; + var initialResponse = await client.GetAsync(""); + var subsequentResponse = await client.GetAsync(""); - await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void Serves304_IfIfModifiedSince_Satisfied() { - var builder = TestUtils.CreateBuilderWithResponseCache(); + var builders = TestUtils.CreateBuildersWithResponseCache(); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.IfUnmodifiedSince = DateTimeOffset.MaxValue; - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.IfUnmodifiedSince = DateTimeOffset.MaxValue; + var subsequentResponse = await client.GetAsync(""); - initialResponse.EnsureSuccessStatusCode(); - Assert.Equal(System.Net.HttpStatusCode.NotModified, subsequentResponse.StatusCode); + initialResponse.EnsureSuccessStatusCode(); + Assert.Equal(System.Net.HttpStatusCode.NotModified, subsequentResponse.StatusCode); + } } } [Fact] public async void ServesCachedContent_IfIfModifiedSince_NotSatisfied() { - var builder = TestUtils.CreateBuilderWithResponseCache(); + var builders = TestUtils.CreateBuildersWithResponseCache(); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.IfUnmodifiedSince = DateTimeOffset.MinValue; - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.IfUnmodifiedSince = DateTimeOffset.MinValue; + var subsequentResponse = await client.GetAsync(""); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void Serves304_IfIfNoneMatch_Satisfied() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { var headers = context.Response.GetTypedHeaders().ETag = new EntityTagHeaderValue("\"E1\""); await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.IfNoneMatch.Add(new System.Net.Http.Headers.EntityTagHeaderValue("\"E1\"")); - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.IfNoneMatch.Add(new System.Net.Http.Headers.EntityTagHeaderValue("\"E1\"")); + var subsequentResponse = await client.GetAsync(""); - initialResponse.EnsureSuccessStatusCode(); - Assert.Equal(System.Net.HttpStatusCode.NotModified, subsequentResponse.StatusCode); + initialResponse.EnsureSuccessStatusCode(); + Assert.Equal(System.Net.HttpStatusCode.NotModified, subsequentResponse.StatusCode); + } } } [Fact] public async void ServesCachedContent_IfIfNoneMatch_NotSatisfied() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { var headers = context.Response.GetTypedHeaders().ETag = new EntityTagHeaderValue("\"E1\""); await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.IfNoneMatch.Add(new System.Net.Http.Headers.EntityTagHeaderValue("\"E2\"")); - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.IfNoneMatch.Add(new System.Net.Http.Headers.EntityTagHeaderValue("\"E2\"")); + var subsequentResponse = await client.GetAsync(""); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesCachedContent_IfBodySize_IsCacheable() { - var builder = TestUtils.CreateBuilderWithResponseCache(options: new ResponseCacheOptions() + var builders = TestUtils.CreateBuildersWithResponseCache(options: new ResponseCacheOptions() { - MaximumCachedBodySize = 100 + MaximumBodySize = 100 }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + var subsequentResponse = await client.GetAsync(""); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesFreshContent_IfBodySize_IsNotCacheable() { - var builder = TestUtils.CreateBuilderWithResponseCache(options: new ResponseCacheOptions() + var builders = TestUtils.CreateBuildersWithResponseCache(options: new ResponseCacheOptions() { - MaximumCachedBodySize = 1 + MaximumBodySize = 1 }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - var initialResponse = await client.GetAsync(""); - var subsequentResponse = await client.GetAsync("/different"); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + var initialResponse = await client.GetAsync(""); + var subsequentResponse = await client.GetAsync("/different"); - await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesCachedContent_WithoutReplacingCachedVaryBy_OnCacheMiss() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.Response.Headers[HeaderNames.Vary] = HeaderNames.From; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - client.DefaultRequestHeaders.From = "user@example.com"; - var initialResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.From = "user2@example.com"; - var otherResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.From = "user@example.com"; - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + client.DefaultRequestHeaders.From = "user@example.com"; + var initialResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.From = "user2@example.com"; + var otherResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.From = "user@example.com"; + var subsequentResponse = await client.GetAsync(""); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesFreshContent_IfCachedVaryByUpdated_OnCacheMiss() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.Response.Headers[HeaderNames.Vary] = context.Request.Headers[HeaderNames.Pragma]; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - client.DefaultRequestHeaders.From = "user@example.com"; - client.DefaultRequestHeaders.Pragma.Clear(); - client.DefaultRequestHeaders.Pragma.Add(new System.Net.Http.Headers.NameValueHeaderValue("From")); - client.DefaultRequestHeaders.MaxForwards = 1; - var initialResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.From = "user2@example.com"; - client.DefaultRequestHeaders.Pragma.Clear(); - client.DefaultRequestHeaders.Pragma.Add(new System.Net.Http.Headers.NameValueHeaderValue("Max-Forwards")); - client.DefaultRequestHeaders.MaxForwards = 2; - var otherResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.From = "user@example.com"; - client.DefaultRequestHeaders.Pragma.Clear(); - client.DefaultRequestHeaders.Pragma.Add(new System.Net.Http.Headers.NameValueHeaderValue("From")); - client.DefaultRequestHeaders.MaxForwards = 1; - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + client.DefaultRequestHeaders.From = "user@example.com"; + client.DefaultRequestHeaders.Pragma.Clear(); + client.DefaultRequestHeaders.Pragma.Add(new System.Net.Http.Headers.NameValueHeaderValue("From")); + client.DefaultRequestHeaders.MaxForwards = 1; + var initialResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.From = "user2@example.com"; + client.DefaultRequestHeaders.Pragma.Clear(); + client.DefaultRequestHeaders.Pragma.Add(new System.Net.Http.Headers.NameValueHeaderValue("Max-Forwards")); + client.DefaultRequestHeaders.MaxForwards = 2; + var otherResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.From = "user@example.com"; + client.DefaultRequestHeaders.Pragma.Clear(); + client.DefaultRequestHeaders.Pragma.Add(new System.Net.Http.Headers.NameValueHeaderValue("From")); + client.DefaultRequestHeaders.MaxForwards = 1; + var subsequentResponse = await client.GetAsync(""); - await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + await AssertResponseNotCachedAsync(initialResponse, subsequentResponse); + } } } [Fact] public async void ServesCachedContent_IfCachedVaryByNotUpdated_OnCacheMiss() { - var builder = TestUtils.CreateBuilderWithResponseCache(requestDelegate: async (context) => + var builders = TestUtils.CreateBuildersWithResponseCache(requestDelegate: async (context) => { context.Response.Headers[HeaderNames.Vary] = context.Request.Headers[HeaderNames.Pragma]; await TestUtils.TestRequestDelegate(context); }); - using (var server = new TestServer(builder)) + foreach (var builder in builders) { - var client = server.CreateClient(); - client.DefaultRequestHeaders.From = "user@example.com"; - client.DefaultRequestHeaders.Pragma.Clear(); - client.DefaultRequestHeaders.Pragma.Add(new System.Net.Http.Headers.NameValueHeaderValue("From")); - client.DefaultRequestHeaders.MaxForwards = 1; - var initialResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.From = "user2@example.com"; - client.DefaultRequestHeaders.Pragma.Clear(); - client.DefaultRequestHeaders.Pragma.Add(new System.Net.Http.Headers.NameValueHeaderValue("From")); - client.DefaultRequestHeaders.MaxForwards = 2; - var otherResponse = await client.GetAsync(""); - client.DefaultRequestHeaders.From = "user@example.com"; - client.DefaultRequestHeaders.Pragma.Clear(); - client.DefaultRequestHeaders.Pragma.Add(new System.Net.Http.Headers.NameValueHeaderValue("From")); - client.DefaultRequestHeaders.MaxForwards = 1; - var subsequentResponse = await client.GetAsync(""); + using (var server = new TestServer(builder)) + { + var client = server.CreateClient(); + client.DefaultRequestHeaders.From = "user@example.com"; + client.DefaultRequestHeaders.Pragma.Clear(); + client.DefaultRequestHeaders.Pragma.Add(new System.Net.Http.Headers.NameValueHeaderValue("From")); + client.DefaultRequestHeaders.MaxForwards = 1; + var initialResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.From = "user2@example.com"; + client.DefaultRequestHeaders.Pragma.Clear(); + client.DefaultRequestHeaders.Pragma.Add(new System.Net.Http.Headers.NameValueHeaderValue("From")); + client.DefaultRequestHeaders.MaxForwards = 2; + var otherResponse = await client.GetAsync(""); + client.DefaultRequestHeaders.From = "user@example.com"; + client.DefaultRequestHeaders.Pragma.Clear(); + client.DefaultRequestHeaders.Pragma.Add(new System.Net.Http.Headers.NameValueHeaderValue("From")); + client.DefaultRequestHeaders.MaxForwards = 1; + var subsequentResponse = await client.GetAsync(""); - await AssertResponseCachedAsync(initialResponse, subsequentResponse); + await AssertResponseCachedAsync(initialResponse, subsequentResponse); + } } } diff --git a/test/Microsoft.AspNetCore.ResponseCaching.Tests/SegmentReadStreamTests.cs b/test/Microsoft.AspNetCore.ResponseCaching.Tests/SegmentReadStreamTests.cs new file mode 100644 index 0000000000..5247df3096 --- /dev/null +++ b/test/Microsoft.AspNetCore.ResponseCaching.Tests/SegmentReadStreamTests.cs @@ -0,0 +1,285 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.AspNetCore.ResponseCaching.Internal; +using Xunit; + +namespace Microsoft.AspNetCore.ResponseCaching.Tests +{ + public class SegmentReadStreamTests + { + public class TestStreamInitInfo + { + internal List Segments { get; set; } + internal int SegmentSize { get; set; } + internal long Length { get; set; } + } + + public static TheoryData TestStreams + { + get + { + return new TheoryData + { + // Partial Segment + new TestStreamInitInfo() + { + Segments = new List(new[] + { + new byte[] { 0, 1, 2, 3, 4 }, + new byte[] { 5, 6, 7, 8, 9 }, + new byte[] { 10, 11, 12 }, + }), + SegmentSize = 5, + Length = 13 + }, + // Full Segments + new TestStreamInitInfo() + { + Segments = new List(new[] + { + new byte[] { 0, 1, 2, 3, 4 }, + new byte[] { 5, 6, 7, 8, 9 }, + new byte[] { 10, 11, 12, 13, 14 }, + }), + SegmentSize = 5, + Length = 15 + } + }; + } + } + + [Fact] + public void SegmentReadStream_NullSegments_Throws() + { + Assert.Throws(() => new SegmentReadStream(null, 0)); + } + + [Fact] + public void Position_ResetToZero_Succeeds() + { + var stream = new SegmentReadStream(new List(), 0); + + // This should not throw + stream.Position = 0; + } + + [Theory] + [InlineData(1)] + [InlineData(-1)] + [InlineData(100)] + [InlineData(long.MaxValue)] + [InlineData(long.MinValue)] + public void Position_SetToNonZero_Throws(long position) + { + var stream = new SegmentReadStream(new List(new[] { new byte[100] }), 100); + + Assert.Throws(() => stream.Position = position); + } + + [Fact] + public void WriteOperations_Throws() + { + var stream = new SegmentReadStream(new List(), 0); + + + Assert.Throws(() => stream.Flush()); + Assert.Throws(() => stream.Write(new byte[1], 0, 0)); + } + + [Fact] + public void SetLength_Throws() + { + var stream = new SegmentReadStream(new List(), 0); + + Assert.Throws(() => stream.SetLength(0)); + } + + [Theory] + [InlineData(SeekOrigin.Current)] + [InlineData(SeekOrigin.End)] + public void Seek_NotBegin_Throws(SeekOrigin origin) + { + var stream = new SegmentReadStream(new List(), 0); + + Assert.Throws(() => stream.Seek(0, origin)); + } + + [Theory] + [InlineData(1)] + [InlineData(-1)] + [InlineData(100)] + [InlineData(long.MaxValue)] + [InlineData(long.MinValue)] + public void Seek_NotZero_Throws(long offset) + { + var stream = new SegmentReadStream(new List(), 0); + + Assert.Throws(() => stream.Seek(offset, SeekOrigin.Begin)); + } + + [Theory] + [MemberData(nameof(TestStreams))] + public void ReadByte_CanReadAllBytes(TestStreamInitInfo info) + { + var stream = new SegmentReadStream(info.Segments, info.Length); + + for (var i = 0; i < stream.Length; i++) + { + Assert.Equal(i, stream.Position); + Assert.Equal(i, stream.ReadByte()); + } + Assert.Equal(stream.Length, stream.Position); + Assert.Equal(-1, stream.ReadByte()); + Assert.Equal(stream.Length, stream.Position); + } + + [Theory] + [MemberData(nameof(TestStreams))] + public void Read_CountLessThanSegmentSize_CanReadAllBytes(TestStreamInitInfo info) + { + var stream = new SegmentReadStream(info.Segments, info.Length); + var count = info.SegmentSize - 1; + + for (var i = 0; i < stream.Length; i+=count) + { + var output = new byte[count]; + var expectedOutput = new byte[count]; + var expectedBytesRead = Math.Min(count, stream.Length - i); + for (var j = 0; j < expectedBytesRead; j++) + { + expectedOutput[j] = (byte)(i + j); + } + Assert.Equal(i, stream.Position); + Assert.Equal(expectedBytesRead, stream.Read(output, 0, count)); + Assert.True(expectedOutput.SequenceEqual(output)); + } + Assert.Equal(stream.Length, stream.Position); + Assert.Equal(0, stream.Read(new byte[count], 0, count)); + Assert.Equal(stream.Length, stream.Position); + } + + [Theory] + [MemberData(nameof(TestStreams))] + public void Read_CountEqualSegmentSize_CanReadAllBytes(TestStreamInitInfo info) + { + var stream = new SegmentReadStream(info.Segments, info.Length); + var count = info.SegmentSize; + + for (var i = 0; i < stream.Length; i += count) + { + var output = new byte[count]; + var expectedOutput = new byte[count]; + var expectedBytesRead = Math.Min(count, stream.Length - i); + for (var j = 0; j < expectedBytesRead; j++) + { + expectedOutput[j] = (byte)(i + j); + } + Assert.Equal(i, stream.Position); + Assert.Equal(expectedBytesRead, stream.Read(output, 0, count)); + Assert.True(expectedOutput.SequenceEqual(output)); + } + Assert.Equal(stream.Length, stream.Position); + Assert.Equal(0, stream.Read(new byte[count], 0, count)); + Assert.Equal(stream.Length, stream.Position); + } + + [Theory] + [MemberData(nameof(TestStreams))] + public void Read_CountGreaterThanSegmentSize_CanReadAllBytes(TestStreamInitInfo info) + { + var stream = new SegmentReadStream(info.Segments, info.Length); + var count = info.SegmentSize + 1; + + for (var i = 0; i < stream.Length; i += count) + { + var output = new byte[count]; + var expectedOutput = new byte[count]; + var expectedBytesRead = Math.Min(count, stream.Length - i); + for (var j = 0; j < expectedBytesRead; j++) + { + expectedOutput[j] = (byte)(i + j); + } + Assert.Equal(i, stream.Position); + Assert.Equal(expectedBytesRead, stream.Read(output, 0, count)); + Assert.True(expectedOutput.SequenceEqual(output)); + } + Assert.Equal(stream.Length, stream.Position); + Assert.Equal(0, stream.Read(new byte[count], 0, count)); + Assert.Equal(stream.Length, stream.Position); + } + + [Theory] + [MemberData(nameof(TestStreams))] + public void CopyToAsync_CopiesAllBytes(TestStreamInitInfo info) + { + var stream = new SegmentReadStream(info.Segments, info.Length); + var writeStream = new SegmentWriteStream(info.SegmentSize); + + stream.CopyTo(writeStream); + + Assert.Equal(stream.Length, stream.Position); + Assert.Equal(stream.Length, writeStream.Length); + var writeSegments = writeStream.GetSegments(); + for (var i = 0; i < info.Segments.Count; i++) + { + Assert.True(writeSegments[i].SequenceEqual(info.Segments[i])); + } + } + + [Theory] + [MemberData(nameof(TestStreams))] + public void CopyToAsync_CopiesFromCurrentPosition(TestStreamInitInfo info) + { + var skippedBytes = info.SegmentSize; + var writeStream = new SegmentWriteStream((int)info.Length); + var stream = new SegmentReadStream(info.Segments, info.Length); + stream.Read(new byte[skippedBytes], 0, skippedBytes); + + stream.CopyTo(writeStream); + + Assert.Equal(stream.Length, stream.Position); + Assert.Equal(stream.Length - skippedBytes, writeStream.Length); + var writeSegments = writeStream.GetSegments(); + + for (var i = skippedBytes; i < info.Length; i++) + { + Assert.Equal(info.Segments[i / info.SegmentSize][i % info.SegmentSize], writeSegments[0][i - skippedBytes]); + } + } + + [Theory] + [MemberData(nameof(TestStreams))] + public void CopyToAsync_CopiesFromStart_AfterReset(TestStreamInitInfo info) + { + var skippedBytes = info.SegmentSize; + var writeStream = new SegmentWriteStream(info.SegmentSize); + var stream = new SegmentReadStream(info.Segments, info.Length); + stream.Read(new byte[skippedBytes], 0, skippedBytes); + + stream.CopyTo(writeStream); + + // Assert bytes read from current location to the end + Assert.Equal(stream.Length, stream.Position); + Assert.Equal(stream.Length - skippedBytes, writeStream.Length); + + // Reset + stream.Position = 0; + writeStream = new SegmentWriteStream(info.SegmentSize); + + stream.CopyTo(writeStream); + + Assert.Equal(stream.Length, stream.Position); + Assert.Equal(stream.Length, writeStream.Length); + var writeSegments = writeStream.GetSegments(); + for (var i = 0; i < info.Segments.Count; i++) + { + Assert.True(writeSegments[i].SequenceEqual(info.Segments[i])); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.ResponseCaching.Tests/SegmentWriteStreamTests.cs b/test/Microsoft.AspNetCore.ResponseCaching.Tests/SegmentWriteStreamTests.cs new file mode 100644 index 0000000000..203b685b8d --- /dev/null +++ b/test/Microsoft.AspNetCore.ResponseCaching.Tests/SegmentWriteStreamTests.cs @@ -0,0 +1,113 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Linq; +using Microsoft.AspNetCore.ResponseCaching.Internal; +using Xunit; + +namespace Microsoft.AspNetCore.ResponseCaching.Tests +{ + public class SegmentWriteStreamTests + { + private static byte[] WriteData = new byte[] + { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 + }; + + [Theory] + [InlineData(0)] + [InlineData(-1)] + public void SegmentWriteStream_InvalidSegmentSize_Throws(int segmentSize) + { + Assert.Throws(() => new SegmentWriteStream(segmentSize)); + } + + [Fact] + public void ReadAndSeekOperations_Throws() + { + var stream = new SegmentWriteStream(1); + + Assert.Throws(() => stream.Read(new byte[1], 0, 0)); + Assert.Throws(() => stream.Position = 0); + Assert.Throws(() => stream.Seek(0, SeekOrigin.Begin)); + } + + [Fact] + public void GetSegments_ExtractionDisablesWriting() + { + var stream = new SegmentWriteStream(1); + + Assert.True(stream.CanWrite); + Assert.Equal(0, stream.GetSegments().Count); + Assert.False(stream.CanWrite); + } + + [Theory] + [InlineData(4)] + [InlineData(5)] + [InlineData(6)] + public void WriteByte_CanWriteAllBytes(int segmentSize) + { + var stream = new SegmentWriteStream(segmentSize); + + foreach (var datum in WriteData) + { + stream.WriteByte(datum); + } + var segments = stream.GetSegments(); + + Assert.Equal(WriteData.Length, stream.Length); + Assert.Equal((WriteData.Length + segmentSize - 1)/ segmentSize, segments.Count); + + for (var i = 0; i < WriteData.Length; i += segmentSize) + { + var expectedSegmentSize = Math.Min(segmentSize, WriteData.Length - i); + var expectedSegment = new byte[expectedSegmentSize]; + for (int j = 0; j < expectedSegmentSize; j++) + { + expectedSegment[j] = (byte)(i + j); + } + var segment = segments[i / segmentSize]; + + Assert.Equal(expectedSegmentSize, segment.Length); + Assert.True(expectedSegment.SequenceEqual(segment)); + } + } + + [Theory] + [InlineData(4)] + [InlineData(5)] + [InlineData(6)] + public void Write_CanWriteAllBytes(int writeSize) + { + var segmentSize = 5; + var stream = new SegmentWriteStream(segmentSize); + + + for (var i = 0; i < WriteData.Length; i += writeSize) + { + stream.Write(WriteData, i, Math.Min(writeSize, WriteData.Length - i)); + } + var segments = stream.GetSegments(); + + Assert.Equal(WriteData.Length, stream.Length); + Assert.Equal((WriteData.Length + segmentSize - 1) / segmentSize, segments.Count); + + for (var i = 0; i < WriteData.Length; i += segmentSize) + { + var expectedSegmentSize = Math.Min(segmentSize, WriteData.Length - i); + var expectedSegment = new byte[expectedSegmentSize]; + for (int j = 0; j < expectedSegmentSize; j++) + { + expectedSegment[j] = (byte)(i + j); + } + var segment = segments[i / segmentSize]; + + Assert.Equal(expectedSegmentSize, segment.Length); + Assert.True(expectedSegment.SequenceEqual(segment)); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs b/test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs index d938065892..f6f0ed2db1 100644 --- a/test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs +++ b/test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs @@ -9,6 +9,7 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.ResponseCaching.Internal; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Internal; using Microsoft.Extensions.ObjectPool; @@ -20,6 +21,12 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests { internal class TestUtils { + static TestUtils() + { + // Force sharding in tests + StreamUtilities.BodySegmentSize = 10; + } + internal static RequestDelegate TestRequestDelegate = async (context) => { var uniqueId = Guid.NewGuid().ToString(); @@ -44,7 +51,7 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests return new ResponseCacheKeyProvider(new DefaultObjectPoolProvider(), Options.Create(options)); } - internal static IWebHostBuilder CreateBuilderWithResponseCache( + internal static IEnumerable CreateBuildersWithResponseCache( Action configureDelegate = null, ResponseCacheOptions options = null, RequestDelegate requestDelegate = null) @@ -62,10 +69,24 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests requestDelegate = TestRequestDelegate; } - return new WebHostBuilder() + // Test with MemoryResponseCacheStore + yield return new WebHostBuilder() .ConfigureServices(services => { - services.AddDistributedResponseCache(); + services.AddMemoryResponseCacheStore(); + }) + .Configure(app => + { + configureDelegate(app); + app.UseResponseCache(options); + app.Run(requestDelegate); + }); + + // Test with DistributedResponseCacheStore + yield return new WebHostBuilder() + .ConfigureServices(services => + { + services.AddDistributedResponseCacheStore(); }) .Configure(app => { @@ -167,11 +188,11 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests internal class TestResponseCacheStore : IResponseCacheStore { - private readonly IDictionary _storage = new Dictionary(); + private readonly IDictionary _storage = new Dictionary(); public int GetCount { get; private set; } public int SetCount { get; private set; } - public Task GetAsync(string key) + public Task GetAsync(string key) { GetCount++; try @@ -180,16 +201,11 @@ namespace Microsoft.AspNetCore.ResponseCaching.Tests } catch { - return Task.FromResult(null); + return Task.FromResult(null); } } - public Task RemoveAsync(string key) - { - return TaskCache.CompletedTask; - } - - public Task SetAsync(string key, object entry, TimeSpan validFor) + public Task SetAsync(string key, IResponseCacheEntry entry, TimeSpan validFor) { SetCount++; _storage[key] = entry;