diff --git a/src/Hosting/TestHost/src/ClientHandler.cs b/src/Hosting/TestHost/src/ClientHandler.cs index fb71cc81b4..f9e8547af2 100644 --- a/src/Hosting/TestHost/src/ClientHandler.cs +++ b/src/Hosting/TestHost/src/ClientHandler.cs @@ -119,9 +119,22 @@ namespace Microsoft.AspNetCore.TestHost responseBody = context.Response.Body; }); + var response = new HttpResponseMessage(); + + // Copy trailers to the response message when the response stream is complete + contextBuilder.RegisterResponseReadCompleteCallback(context => + { + var responseTrailersFeature = context.Features.Get(); + + foreach (var trailer in responseTrailersFeature.Trailers) + { + bool success = response.TrailingHeaders.TryAddWithoutValidation(trailer.Key, (IEnumerable)trailer.Value); + Contract.Assert(success, "Bad trailer"); + } + }); + var httpContext = await contextBuilder.SendAsync(cancellationToken); - var response = new HttpResponseMessage(); response.StatusCode = (HttpStatusCode)httpContext.Response.StatusCode; response.ReasonPhrase = httpContext.Features.Get().ReasonPhrase; response.RequestMessage = request; diff --git a/src/Hosting/TestHost/src/HttpContextBuilder.cs b/src/Hosting/TestHost/src/HttpContextBuilder.cs index 9fd2beb547..710ffa3d86 100644 --- a/src/Hosting/TestHost/src/HttpContextBuilder.cs +++ b/src/Hosting/TestHost/src/HttpContextBuilder.cs @@ -17,13 +17,15 @@ namespace Microsoft.AspNetCore.TestHost private readonly IHttpApplication _application; private readonly bool _preserveExecutionContext; private readonly HttpContext _httpContext; - - private TaskCompletionSource _responseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - private ResponseStream _responseStream; - private ResponseFeature _responseFeature = new ResponseFeature(); - private RequestLifetimeFeature _requestLifetimeFeature = new RequestLifetimeFeature(); + + private readonly TaskCompletionSource _responseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly ResponseStream _responseStream; + private readonly ResponseFeature _responseFeature = new ResponseFeature(); + private readonly RequestLifetimeFeature _requestLifetimeFeature = new RequestLifetimeFeature(); + private readonly ResponseTrailersFeature _responseTrailersFeature = new ResponseTrailersFeature(); private bool _pipelineFinished; private Context _testContext; + private Action _responseReadCompleteCallback; internal HttpContextBuilder(IHttpApplication application, bool allowSynchronousIO, bool preserveExecutionContext) { @@ -39,8 +41,9 @@ namespace Microsoft.AspNetCore.TestHost _httpContext.Features.Set(this); _httpContext.Features.Set(_responseFeature); _httpContext.Features.Set(_requestLifetimeFeature); - - _responseStream = new ResponseStream(ReturnResponseMessageAsync, AbortRequest, () => AllowSynchronousIO); + _httpContext.Features.Set(_responseTrailersFeature); + + _responseStream = new ResponseStream(ReturnResponseMessageAsync, AbortRequest, () => AllowSynchronousIO, () => _responseReadCompleteCallback?.Invoke(_httpContext)); _responseFeature.Body = _responseStream; } @@ -56,6 +59,11 @@ namespace Microsoft.AspNetCore.TestHost configureContext(_httpContext); } + internal void RegisterResponseReadCompleteCallback(Action responseReadCompleteCallback) + { + _responseReadCompleteCallback = responseReadCompleteCallback; + } + /// /// Start processing the request. /// diff --git a/src/Hosting/TestHost/src/ResponseStream.cs b/src/Hosting/TestHost/src/ResponseStream.cs index 7563beb4c3..efcf35f494 100644 --- a/src/Hosting/TestHost/src/ResponseStream.cs +++ b/src/Hosting/TestHost/src/ResponseStream.cs @@ -17,22 +17,24 @@ namespace Microsoft.AspNetCore.TestHost internal class ResponseStream : Stream { private bool _complete; + private bool _readerComplete; private bool _aborted; private Exception _abortException; - private SemaphoreSlim _writeLock; - - private Func _onFirstWriteAsync; private bool _firstWrite; - private Action _abortRequest; - private Func _allowSynchronousIO; - private Pipe _pipe = new Pipe(); + private readonly SemaphoreSlim _writeLock; + private readonly Func _onFirstWriteAsync; + private readonly Action _abortRequest; + private readonly Func _allowSynchronousIO; + private readonly Action _readComplete; + private readonly Pipe _pipe = new Pipe(); - internal ResponseStream(Func onFirstWriteAsync, Action abortRequest, Func allowSynchronousIO) + internal ResponseStream(Func onFirstWriteAsync, Action abortRequest, Func allowSynchronousIO, Action readComplete) { _onFirstWriteAsync = onFirstWriteAsync ?? throw new ArgumentNullException(nameof(onFirstWriteAsync)); _abortRequest = abortRequest ?? throw new ArgumentNullException(nameof(abortRequest)); _allowSynchronousIO = allowSynchronousIO ?? throw new ArgumentNullException(nameof(allowSynchronousIO)); + _readComplete = readComplete; _firstWrite = true; _writeLock = new SemaphoreSlim(1, 1); } @@ -108,6 +110,12 @@ namespace Microsoft.AspNetCore.TestHost { VerifyBuffer(buffer, offset, count, allowEmpty: false); CheckAborted(); + + if (_readerComplete) + { + return 0; + } + var registration = cancellationToken.Register(Cancel); try { @@ -116,6 +124,8 @@ namespace Microsoft.AspNetCore.TestHost if (result.Buffer.IsEmpty && result.IsCompleted) { _pipe.Reader.Complete(); + _readComplete(); + _readerComplete = true; return 0; } diff --git a/src/Hosting/TestHost/src/ResponseTrailersFeature.cs b/src/Hosting/TestHost/src/ResponseTrailersFeature.cs new file mode 100644 index 0000000000..12460d0c01 --- /dev/null +++ b/src/Hosting/TestHost/src/ResponseTrailersFeature.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.TestHost +{ + internal class ResponseTrailersFeature : IHttpResponseTrailersFeature + { + public IHeaderDictionary Trailers { get; set; } = new HeaderDictionary(); + } +} diff --git a/src/Hosting/TestHost/test/ClientHandlerTests.cs b/src/Hosting/TestHost/test/ClientHandlerTests.cs index f97afcd5a6..73d37c0159 100644 --- a/src/Hosting/TestHost/test/ClientHandlerTests.cs +++ b/src/Hosting/TestHost/test/ClientHandlerTests.cs @@ -88,6 +88,71 @@ namespace Microsoft.AspNetCore.TestHost return httpClient.GetAsync("https://example.com/"); } + [Fact] + public async Task ServerTrailersSetOnResponseAfterContentRead() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var handler = new ClientHandler(PathString.Empty, new DummyApplication(async context => + { + context.Response.AppendTrailer("StartTrailer", "Value!"); + + await context.Response.WriteAsync("Hello World"); + await context.Response.Body.FlushAsync(); + + // Pause writing response to ensure trailers are written at the end + await tcs.Task; + + await context.Response.WriteAsync("Bye World"); + await context.Response.Body.FlushAsync(); + + context.Response.AppendTrailer("EndTrailer", "Value!"); + })); + + var invoker = new HttpMessageInvoker(handler); + var message = new HttpRequestMessage(HttpMethod.Post, "https://example.com/"); + + var response = await invoker.SendAsync(message, CancellationToken.None); + + Assert.Empty(response.TrailingHeaders); + + var responseBody = await response.Content.ReadAsStreamAsync(); + + int read = await responseBody.ReadAsync(new byte[100], 0, 100); + Assert.Equal(11, read); + + Assert.Empty(response.TrailingHeaders); + + var readTask = responseBody.ReadAsync(new byte[100], 0, 100); + Assert.False(readTask.IsCompleted); + tcs.TrySetResult(null); + + read = await readTask; + Assert.Equal(9, read); + + Assert.Empty(response.TrailingHeaders); + + // Read nothing because we're at the end of the response + read = await responseBody.ReadAsync(new byte[100], 0, 100); + Assert.Equal(0, read); + + // Ensure additional reads after end don't effect trailers + read = await responseBody.ReadAsync(new byte[100], 0, 100); + Assert.Equal(0, read); + + Assert.Collection(response.TrailingHeaders, + kvp => + { + Assert.Equal("StartTrailer", kvp.Key); + Assert.Equal("Value!", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("EndTrailer", kvp.Key); + Assert.Equal("Value!", kvp.Value.Single()); + }); + } + [Fact] public async Task ResubmitRequestWorks() {