Improve TestServer support for Response.StartAsync (#10189)

This commit is contained in:
James Newton-King 2019-05-13 22:38:26 +12:00 committed by GitHub
parent 208299aa31
commit d5207af367
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 112 additions and 15 deletions

View File

@ -20,10 +20,11 @@ namespace Microsoft.AspNetCore.TestHost
private readonly TaskCompletionSource<HttpContext> _responseTcs = new TaskCompletionSource<HttpContext>(TaskCreationOptions.RunContinuationsAsynchronously); private readonly TaskCompletionSource<HttpContext> _responseTcs = new TaskCompletionSource<HttpContext>(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly ResponseStream _responseStream; private readonly ResponseStream _responseStream;
private readonly ResponseFeature _responseFeature = new ResponseFeature(); private readonly ResponseFeature _responseFeature;
private readonly RequestLifetimeFeature _requestLifetimeFeature = new RequestLifetimeFeature(); private readonly RequestLifetimeFeature _requestLifetimeFeature = new RequestLifetimeFeature();
private readonly ResponseTrailersFeature _responseTrailersFeature = new ResponseTrailersFeature(); private readonly ResponseTrailersFeature _responseTrailersFeature = new ResponseTrailersFeature();
private bool _pipelineFinished; private bool _pipelineFinished;
private bool _returningResponse;
private Context _testContext; private Context _testContext;
private Action<HttpContext> _responseReadCompleteCallback; private Action<HttpContext> _responseReadCompleteCallback;
@ -33,6 +34,7 @@ namespace Microsoft.AspNetCore.TestHost
AllowSynchronousIO = allowSynchronousIO; AllowSynchronousIO = allowSynchronousIO;
_preserveExecutionContext = preserveExecutionContext; _preserveExecutionContext = preserveExecutionContext;
_httpContext = new DefaultHttpContext(); _httpContext = new DefaultHttpContext();
_responseFeature = new ResponseFeature(Abort);
var request = _httpContext.Request; var request = _httpContext.Request;
request.Protocol = "HTTP/1.1"; request.Protocol = "HTTP/1.1";
@ -40,6 +42,7 @@ namespace Microsoft.AspNetCore.TestHost
_httpContext.Features.Set<IHttpBodyControlFeature>(this); _httpContext.Features.Set<IHttpBodyControlFeature>(this);
_httpContext.Features.Set<IHttpResponseFeature>(_responseFeature); _httpContext.Features.Set<IHttpResponseFeature>(_responseFeature);
_httpContext.Features.Set<IHttpResponseStartFeature>(_responseFeature);
_httpContext.Features.Set<IHttpRequestLifetimeFeature>(_requestLifetimeFeature); _httpContext.Features.Set<IHttpRequestLifetimeFeature>(_requestLifetimeFeature);
_httpContext.Features.Set<IHttpResponseTrailersFeature>(_responseTrailersFeature); _httpContext.Features.Set<IHttpResponseTrailersFeature>(_responseTrailersFeature);
@ -132,12 +135,13 @@ namespace Microsoft.AspNetCore.TestHost
internal async Task ReturnResponseMessageAsync() internal async Task ReturnResponseMessageAsync()
{ {
// Check if the response has already started because the TrySetResult below could happen a bit late // Check if the response is already returning because the TrySetResult below could happen a bit late
// (as it happens on a different thread) by which point the CompleteResponseAsync could run and calls this // (as it happens on a different thread) by which point the CompleteResponseAsync could run and calls this
// method again. // method again.
if (!_responseFeature.HasStarted) if (!_returningResponse)
{ {
// Sets HasStarted _returningResponse = true;
try try
{ {
await _responseFeature.FireOnSendingHeadersAsync(); await _responseFeature.FireOnSendingHeadersAsync();

View File

@ -3,21 +3,24 @@
using System; using System;
using System.IO; using System.IO;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Features;
namespace Microsoft.AspNetCore.TestHost namespace Microsoft.AspNetCore.TestHost
{ {
internal class ResponseFeature : IHttpResponseFeature internal class ResponseFeature : IHttpResponseFeature, IHttpResponseStartFeature
{ {
private readonly HeaderDictionary _headers = new HeaderDictionary();
private readonly Action<Exception> _abort;
private Func<Task> _responseStartingAsync = () => Task.FromResult(true); private Func<Task> _responseStartingAsync = () => Task.FromResult(true);
private Func<Task> _responseCompletedAsync = () => Task.FromResult(true); private Func<Task> _responseCompletedAsync = () => Task.FromResult(true);
private HeaderDictionary _headers = new HeaderDictionary();
private int _statusCode; private int _statusCode;
private string _reasonPhrase; private string _reasonPhrase;
public ResponseFeature() public ResponseFeature(Action<Exception> abort)
{ {
Headers = _headers; Headers = _headers;
Body = new MemoryStream(); Body = new MemoryStream();
@ -25,6 +28,7 @@ namespace Microsoft.AspNetCore.TestHost
// 200 is the default status code all the way down to the host, so we set it // 200 is the default status code all the way down to the host, so we set it
// here to be consistent with the rest of the hosts when writing tests. // here to be consistent with the rest of the hosts when writing tests.
StatusCode = 200; StatusCode = 200;
_abort = abort;
} }
public int StatusCode public int StatusCode
@ -98,14 +102,36 @@ namespace Microsoft.AspNetCore.TestHost
public async Task FireOnSendingHeadersAsync() public async Task FireOnSendingHeadersAsync()
{ {
await _responseStartingAsync(); if (!HasStarted)
HasStarted = true; {
_headers.IsReadOnly = true; try
{
await _responseStartingAsync();
}
finally
{
HasStarted = true;
_headers.IsReadOnly = true;
}
}
} }
public Task FireOnResponseCompletedAsync() public Task FireOnResponseCompletedAsync()
{ {
return _responseCompletedAsync(); return _responseCompletedAsync();
} }
public async Task StartAsync(CancellationToken token = default)
{
try
{
await FireOnSendingHeadersAsync();
}
catch (Exception ex)
{
_abort(ex);
throw;
}
}
} }
} }

View File

@ -153,6 +153,48 @@ namespace Microsoft.AspNetCore.TestHost
}); });
} }
[Fact]
public async Task ResponseStartAsync()
{
var hasStartedTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
var hasAssertedResponseTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
bool? preHasStarted = null;
bool? postHasStarted = null;
var handler = new ClientHandler(PathString.Empty, new DummyApplication(async context =>
{
preHasStarted = context.Response.HasStarted;
await context.Response.StartAsync();
postHasStarted = context.Response.HasStarted;
hasStartedTcs.TrySetResult(null);
await hasAssertedResponseTcs.Task;
}));
var invoker = new HttpMessageInvoker(handler);
var message = new HttpRequestMessage(HttpMethod.Post, "https://example.com/");
var responseTask = invoker.SendAsync(message, CancellationToken.None);
// Ensure StartAsync has been called in response
await hasStartedTcs.Task;
// Delay so async thread would have had time to attempt to return response
await Task.Delay(100);
Assert.False(responseTask.IsCompleted, "HttpResponse.StartAsync does not return response");
// Asserted that response return was checked, allow response to finish
hasAssertedResponseTcs.TrySetResult(null);
await responseTask;
Assert.False(preHasStarted);
Assert.True(postHasStarted);
}
[Fact] [Fact]
public async Task ResubmitRequestWorks() public async Task ResubmitRequestWorks()
{ {

View File

@ -13,7 +13,7 @@ namespace Microsoft.AspNetCore.TestHost
public async Task StatusCode_DefaultsTo200() public async Task StatusCode_DefaultsTo200()
{ {
// Arrange & Act // Arrange & Act
var responseInformation = new ResponseFeature(); var responseInformation = CreateResponseFeature();
// Assert // Assert
Assert.Equal(200, responseInformation.StatusCode); Assert.Equal(200, responseInformation.StatusCode);
@ -25,11 +25,27 @@ namespace Microsoft.AspNetCore.TestHost
Assert.True(responseInformation.Headers.IsReadOnly); Assert.True(responseInformation.Headers.IsReadOnly);
} }
[Fact]
public async Task StartAsync_StartsResponse()
{
// Arrange & Act
var responseInformation = CreateResponseFeature();
// Assert
Assert.Equal(200, responseInformation.StatusCode);
Assert.False(responseInformation.HasStarted);
await responseInformation.StartAsync();
Assert.True(responseInformation.HasStarted);
Assert.True(responseInformation.Headers.IsReadOnly);
}
[Fact] [Fact]
public void OnStarting_ThrowsWhenHasStarted() public void OnStarting_ThrowsWhenHasStarted()
{ {
// Arrange // Arrange
var responseInformation = new ResponseFeature(); var responseInformation = CreateResponseFeature();
responseInformation.HasStarted = true; responseInformation.HasStarted = true;
// Act & Assert // Act & Assert
@ -45,7 +61,7 @@ namespace Microsoft.AspNetCore.TestHost
[Fact] [Fact]
public void StatusCode_ThrowsWhenHasStarted() public void StatusCode_ThrowsWhenHasStarted()
{ {
var responseInformation = new ResponseFeature(); var responseInformation = CreateResponseFeature();
responseInformation.HasStarted = true; responseInformation.HasStarted = true;
Assert.Throws<InvalidOperationException>(() => responseInformation.StatusCode = 400); Assert.Throws<InvalidOperationException>(() => responseInformation.StatusCode = 400);
@ -55,7 +71,7 @@ namespace Microsoft.AspNetCore.TestHost
[Fact] [Fact]
public void StatusCode_MustBeGreaterThan99() public void StatusCode_MustBeGreaterThan99()
{ {
var responseInformation = new ResponseFeature(); var responseInformation = CreateResponseFeature();
Assert.Throws<ArgumentOutOfRangeException>(() => responseInformation.StatusCode = 99); Assert.Throws<ArgumentOutOfRangeException>(() => responseInformation.StatusCode = 99);
Assert.Throws<ArgumentOutOfRangeException>(() => responseInformation.StatusCode = 0); Assert.Throws<ArgumentOutOfRangeException>(() => responseInformation.StatusCode = 0);
@ -64,5 +80,10 @@ namespace Microsoft.AspNetCore.TestHost
responseInformation.StatusCode = 200; responseInformation.StatusCode = 200;
responseInformation.StatusCode = 1000; responseInformation.StatusCode = 1000;
} }
private ResponseFeature CreateResponseFeature()
{
return new ResponseFeature(ex => { });
}
} }
} }

View File

@ -7,11 +7,13 @@ using System.Diagnostics;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Net; using System.Net;
using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Xunit; using Xunit;
@ -114,6 +116,8 @@ namespace Microsoft.AspNetCore.Diagnostics
// add response buffering // add response buffering
app.Use(async (httpContext, next) => app.Use(async (httpContext, next) =>
{ {
httpContext.Features.Set<IHttpResponseStartFeature>(null);
var response = httpContext.Response; var response = httpContext.Response;
var originalResponseBody = response.Body; var originalResponseBody = response.Body;
var bufferingStream = new MemoryStream(); var bufferingStream = new MemoryStream();