From 4f7d53f4e745b5d407d91517b24c678fa32fab75 Mon Sep 17 00:00:00 2001 From: Kiran Challa Date: Mon, 16 Apr 2018 16:47:54 -0700 Subject: [PATCH] [Fixes #7658] FileStreamResultExecutor fails to Dispose FileStream --- .../FileStreamResultExecutor.cs | 51 +++++++++++-------- .../FileStreamResultTest.cs | 35 +++++++++++++ 2 files changed, 64 insertions(+), 22 deletions(-) diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Infrastructure/FileStreamResultExecutor.cs b/src/Microsoft.AspNetCore.Mvc.Core/Infrastructure/FileStreamResultExecutor.cs index 46e10e036c..6cb300eff1 100644 --- a/src/Microsoft.AspNetCore.Mvc.Core/Infrastructure/FileStreamResultExecutor.cs +++ b/src/Microsoft.AspNetCore.Mvc.Core/Infrastructure/FileStreamResultExecutor.cs @@ -17,7 +17,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure } /// - public virtual Task ExecuteAsync(ActionContext context, FileStreamResult result) + public virtual async Task ExecuteAsync(ActionContext context, FileStreamResult result) { if (context == null) { @@ -29,31 +29,38 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure throw new ArgumentNullException(nameof(result)); } - Logger.ExecutingFileResult(result); - - long? fileLength = null; - if (result.FileStream.CanSeek) + using (result.FileStream) { - fileLength = result.FileStream.Length; + Logger.ExecutingFileResult(result); + + long? fileLength = null; + if (result.FileStream.CanSeek) + { + fileLength = result.FileStream.Length; + } + + var (range, rangeLength, serveBody) = SetHeadersAndLog( + context, + result, + fileLength, + result.EnableRangeProcessing, + result.LastModified, + result.EntityTag); + + if (!serveBody) + { + return; + } + + await WriteFileAsync(context, result, range, rangeLength); } - - var (range, rangeLength, serveBody) = SetHeadersAndLog( - context, - result, - fileLength, - result.EnableRangeProcessing, - result.LastModified, - result.EntityTag); - - if (!serveBody) - { - return Task.CompletedTask; - } - - return WriteFileAsync(context, result, range, rangeLength); } - protected virtual Task WriteFileAsync(ActionContext context, FileStreamResult result, RangeItemHeaderValue range, long rangeLength) + protected virtual Task WriteFileAsync( + ActionContext context, + FileStreamResult result, + RangeItemHeaderValue range, + long rangeLength) { if (context == null) { diff --git a/test/Microsoft.AspNetCore.Mvc.Core.Test/FileStreamResultTest.cs b/test/Microsoft.AspNetCore.Mvc.Core.Test/FileStreamResultTest.cs index eea79e044f..081eaf7b81 100644 --- a/test/Microsoft.AspNetCore.Mvc.Core.Test/FileStreamResultTest.cs +++ b/test/Microsoft.AspNetCore.Mvc.Core.Test/FileStreamResultTest.cs @@ -126,6 +126,7 @@ namespace Microsoft.AspNetCore.Mvc Assert.Equal(contentRange.ToString(), httpResponse.Headers[HeaderNames.ContentRange]); Assert.Equal(contentLength, httpResponse.ContentLength); Assert.Equal(expectedString, body); + Assert.False(readStream.CanSeek); } [Fact] @@ -174,6 +175,7 @@ namespace Microsoft.AspNetCore.Mvc Assert.Equal(contentRange.ToString(), httpResponse.Headers[HeaderNames.ContentRange]); Assert.Equal(5, httpResponse.ContentLength); Assert.Equal("Hello", body); + Assert.False(readStream.CanSeek); } [Fact] @@ -217,6 +219,7 @@ namespace Microsoft.AspNetCore.Mvc Assert.Equal(lastModified.ToString("R"), httpResponse.Headers[HeaderNames.LastModified]); Assert.Equal(entityTag.ToString(), httpResponse.Headers[HeaderNames.ETag]); Assert.Equal("Hello World", body); + Assert.False(readStream.CanSeek); } [Fact] @@ -261,6 +264,7 @@ namespace Microsoft.AspNetCore.Mvc Assert.Equal(lastModified.ToString("R"), httpResponse.Headers[HeaderNames.LastModified]); Assert.Equal(entityTag.ToString(), httpResponse.Headers[HeaderNames.ETag]); Assert.Equal("Hello World", body); + Assert.False(readStream.CanSeek); } [Theory] @@ -303,6 +307,7 @@ namespace Microsoft.AspNetCore.Mvc Assert.Equal(lastModified.ToString("R"), httpResponse.Headers[HeaderNames.LastModified]); Assert.Equal(entityTag.ToString(), httpResponse.Headers[HeaderNames.ETag]); Assert.Equal("Hello World", body); + Assert.False(readStream.CanSeek); } [Theory] @@ -346,6 +351,7 @@ namespace Microsoft.AspNetCore.Mvc Assert.Equal(contentRange.ToString(), httpResponse.Headers[HeaderNames.ContentRange]); Assert.Equal(11, httpResponse.ContentLength); Assert.Empty(body); + Assert.False(readStream.CanSeek); } [Fact] @@ -389,6 +395,7 @@ namespace Microsoft.AspNetCore.Mvc Assert.Empty(httpResponse.Headers[HeaderNames.ContentRange]); Assert.NotEmpty(httpResponse.Headers[HeaderNames.LastModified]); Assert.Empty(body); + Assert.False(readStream.CanSeek); } [Fact] @@ -432,6 +439,7 @@ namespace Microsoft.AspNetCore.Mvc Assert.Empty(httpResponse.Headers[HeaderNames.ContentRange]); Assert.NotEmpty(httpResponse.Headers[HeaderNames.LastModified]); Assert.Empty(body); + Assert.False(readStream.CanSeek); } [Theory] @@ -480,6 +488,7 @@ namespace Microsoft.AspNetCore.Mvc Assert.Equal("bytes", httpResponse.Headers[HeaderNames.AcceptRanges]); Assert.Equal(contentRange.ToString(), httpResponse.Headers[HeaderNames.ContentRange]); Assert.Empty(body); + Assert.False(readStream.CanSeek); } [Fact] @@ -541,6 +550,7 @@ namespace Microsoft.AspNetCore.Mvc // Assert var outBytes = outStream.ToArray(); Assert.True(originalBytes.SequenceEqual(outBytes)); + Assert.False(originalStream.CanSeek); } [Fact] @@ -570,6 +580,31 @@ namespace Microsoft.AspNetCore.Mvc var outBytes = outStream.ToArray(); Assert.True(originalBytes.SequenceEqual(outBytes)); Assert.Equal(expectedContentType, httpContext.Response.ContentType); + Assert.False(originalStream.CanSeek); + } + + [Fact] + public async Task HeadRequest_DoesNotWriteToBody_AndClosesReadStream() + { + // Arrange + var readStream = new MemoryStream(Encoding.UTF8.GetBytes("Hello, World!")); + + var httpContext = GetHttpContext(); + httpContext.Request.Method = "HEAD"; + var outStream = new MemoryStream(); + httpContext.Response.Body = outStream; + + var actionContext = new ActionContext(httpContext, new RouteData(), new ActionDescriptor()); + + var result = new FileStreamResult(readStream, "text/plain"); + + // Act + await result.ExecuteResultAsync(actionContext); + + // Assert + Assert.False(readStream.CanSeek); + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.Equal(0, httpContext.Response.Body.Length); } private static IServiceCollection CreateServices()