Fix TestServer from blocking on request stream (#15591)

This commit is contained in:
James Newton-King 2019-11-01 10:48:16 +13:00 committed by GitHub
parent 3ceca46c5b
commit cfea2e91dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 399 additions and 36 deletions

View File

@ -5,6 +5,7 @@ using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Net;
using System.Net.Http;
@ -65,8 +66,33 @@ namespace Microsoft.AspNetCore.TestHost
var contextBuilder = new HttpContextBuilder(_application, AllowSynchronousIO, PreserveExecutionContext);
var requestContent = request.Content ?? new StreamContent(Stream.Null);
var body = await requestContent.ReadAsStreamAsync();
contextBuilder.Configure(context =>
// Read content from the request HttpContent into a pipe in a background task. This will allow the request
// delegate to start before the request HttpContent is complete. A background task allows duplex streaming scenarios.
contextBuilder.SendRequestStream(async writer =>
{
if (requestContent is StreamContent)
{
// This is odd but required for backwards compat. If StreamContent is passed in then seek to beginning.
// This is safe because StreamContent.ReadAsStreamAsync doesn't block. It will return the inner stream.
var body = await requestContent.ReadAsStreamAsync();
if (body.CanSeek)
{
// This body may have been consumed before, rewind it.
body.Seek(0, SeekOrigin.Begin);
}
await body.CopyToAsync(writer);
}
else
{
await requestContent.CopyToAsync(writer.AsStream());
}
await writer.CompleteAsync();
});
contextBuilder.Configure((context, reader) =>
{
var req = context.Request;
@ -115,12 +141,7 @@ namespace Microsoft.AspNetCore.TestHost
}
}
if (body.CanSeek)
{
// This body may have been consumed before, rewind it.
body.Seek(0, SeekOrigin.Begin);
}
req.Body = new AsyncStreamWrapper(body, () => contextBuilder.AllowSynchronousIO);
req.Body = new AsyncStreamWrapper(reader.AsStream(), () => contextBuilder.AllowSynchronousIO);
});
var response = new HttpResponseMessage();

View File

@ -26,7 +26,10 @@ namespace Microsoft.AspNetCore.TestHost
private bool _pipelineFinished;
private bool _returningResponse;
private object _testContext;
private Pipe _requestPipe;
private Action<HttpContext> _responseReadCompleteCallback;
private Task _sendRequestStreamTask;
internal HttpContextBuilder(ApplicationWrapper application, bool allowSynchronousIO, bool preserveExecutionContext)
{
@ -41,9 +44,11 @@ namespace Microsoft.AspNetCore.TestHost
request.Protocol = "HTTP/1.1";
request.Method = HttpMethods.Get;
var pipe = new Pipe();
_responseReaderStream = new ResponseBodyReaderStream(pipe, ClientInitiatedAbort, () => _responseReadCompleteCallback?.Invoke(_httpContext));
_responsePipeWriter = new ResponseBodyPipeWriter(pipe, ReturnResponseMessageAsync);
_requestPipe = new Pipe();
var responsePipe = new Pipe();
_responseReaderStream = new ResponseBodyReaderStream(responsePipe, ClientInitiatedAbort, () => _responseReadCompleteCallback?.Invoke(_httpContext));
_responsePipeWriter = new ResponseBodyPipeWriter(responsePipe, ReturnResponseMessageAsync);
_responseFeature.Body = new ResponseBodyWriterStream(_responsePipeWriter, () => AllowSynchronousIO);
_responseFeature.BodyWriter = _responsePipeWriter;
@ -56,14 +61,24 @@ namespace Microsoft.AspNetCore.TestHost
public bool AllowSynchronousIO { get; set; }
internal void Configure(Action<HttpContext> configureContext)
internal void Configure(Action<HttpContext, PipeReader> configureContext)
{
if (configureContext == null)
{
throw new ArgumentNullException(nameof(configureContext));
}
configureContext(_httpContext);
configureContext(_httpContext, _requestPipe.Reader);
}
internal void SendRequestStream(Func<PipeWriter, Task> sendRequestStream)
{
if (sendRequestStream == null)
{
throw new ArgumentNullException(nameof(sendRequestStream));
}
_sendRequestStreamTask = sendRequestStream(_requestPipe.Writer);
}
internal void RegisterResponseReadCompleteCallback(Action<HttpContext> responseReadCompleteCallback)
@ -92,10 +107,10 @@ namespace Microsoft.AspNetCore.TestHost
// since we are now inside the Server's execution context. If it happens outside this cont
// it will be lost when we abandon the execution context.
_testContext = _application.CreateContext(_httpContext.Features);
try
{
await _application.ProcessRequestAsync(_testContext);
await CompleteRequestAsync();
await CompleteResponseAsync();
_application.DisposeContext(_testContext, exception: null);
}
@ -134,8 +149,40 @@ namespace Microsoft.AspNetCore.TestHost
// We don't want to trigger the token for already completed responses.
_requestLifetimeFeature.Cancel();
}
// Writes will still succeed, the app will only get an error if they check the CT.
_responseReaderStream.Abort(new IOException("The client aborted the request."));
// Cancel any pending request async activity when the client aborts a duplex
// streaming scenario by disposing the HttpResponseMessage.
CancelRequestBody();
}
private async Task CompleteRequestAsync()
{
if (!_requestPipe.Reader.TryRead(out var result) || !result.IsCompleted)
{
// If request is still in progress then abort it.
CancelRequestBody();
}
else
{
// Writer was already completed in send request callback.
await _requestPipe.Reader.CompleteAsync();
}
if (_sendRequestStreamTask != null)
{
try
{
// Ensure duplex request is either completely read or has been aborted.
await _sendRequestStreamTask;
}
catch (OperationCanceledException)
{
// Request was canceled, likely because it wasn't read before the request ended.
}
}
}
internal async Task CompleteResponseAsync()
@ -192,6 +239,13 @@ namespace Microsoft.AspNetCore.TestHost
_responseReaderStream.Abort(exception);
_requestLifetimeFeature.Cancel();
_responseTcs.TrySetException(exception);
CancelRequestBody();
}
private void CancelRequestBody()
{
_requestPipe.Writer.CancelPendingFlush();
_requestPipe.Reader.CancelPendingRead();
}
void IHttpResetFeature.Reset(int errorCode)

View File

@ -138,7 +138,7 @@ namespace Microsoft.AspNetCore.TestHost
}
var builder = new HttpContextBuilder(Application, AllowSynchronousIO, PreserveExecutionContext);
builder.Configure(context =>
builder.Configure((context, reader) =>
{
var request = context.Request;
request.Scheme = BaseAddress.Scheme;
@ -154,7 +154,7 @@ namespace Microsoft.AspNetCore.TestHost
}
request.PathBase = pathBase;
});
builder.Configure(configureContext);
builder.Configure((context, reader) => configureContext(context));
// TODO: Wrap the request body if any?
return await builder.SendAsync(cancellationToken).ConfigureAwait(false);
}

View File

@ -51,7 +51,7 @@ namespace Microsoft.AspNetCore.TestHost
{
WebSocketFeature webSocketFeature = null;
var contextBuilder = new HttpContextBuilder(_application, AllowSynchronousIO, PreserveExecutionContext);
contextBuilder.Configure(context =>
contextBuilder.Configure((context, reader) =>
{
var request = context.Request;
var scheme = uri.Scheme;

View File

@ -4,6 +4,10 @@
<TargetFramework>$(DefaultNetCoreTargetFramework)</TargetFramework>
</PropertyGroup>
<ItemGroup>
<Compile Include="..\..\..\Shared\SyncPoint\SyncPoint.cs" Link="SyncPoint.cs" />
</ItemGroup>
<ItemGroup>
<Reference Include="Microsoft.AspNetCore.TestHost" />
<Reference Include="Microsoft.Extensions.DiagnosticAdapter" />

View File

@ -5,6 +5,7 @@ using System;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
@ -13,6 +14,7 @@ using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
@ -89,17 +91,20 @@ namespace Microsoft.AspNetCore.TestHost
{
// Arrange
RequestDelegate appDelegate = async ctx =>
await ctx.Response.WriteAsync(await new StreamReader(ctx.Request.Body).ReadToEndAsync() + " PUT Response");
{
var content = await new StreamReader(ctx.Request.Body).ReadToEndAsync();
await ctx.Response.WriteAsync(content + " PUT Response");
};
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
// Act
var content = new StringContent("Hello world");
var response = await client.PutAsync("http://localhost:12345", content);
var response = await client.PutAsync("http://localhost:12345", content).WithTimeout();
// Assert
Assert.Equal("Hello world PUT Response", await response.Content.ReadAsStringAsync());
Assert.Equal("Hello world PUT Response", await response.Content.ReadAsStringAsync().WithTimeout());
}
[Fact]
@ -114,10 +119,10 @@ namespace Microsoft.AspNetCore.TestHost
// Act
var content = new StringContent("Hello world");
var response = await client.PostAsync("http://localhost:12345", content);
var response = await client.PostAsync("http://localhost:12345", content).WithTimeout();
// Assert
Assert.Equal("Hello world POST Response", await response.Content.ReadAsStringAsync());
Assert.Equal("Hello world POST Response", await response.Content.ReadAsStringAsync().WithTimeout());
}
[Fact]
@ -162,6 +167,296 @@ namespace Microsoft.AspNetCore.TestHost
}
}
[Fact]
public async Task ClientStreamingWorks()
{
// Arrange
var responseStartedSyncPoint = new SyncPoint();
var requestEndingSyncPoint = new SyncPoint();
var requestStreamSyncPoint = new SyncPoint();
RequestDelegate appDelegate = async ctx =>
{
// Send headers
await ctx.Response.BodyWriter.FlushAsync();
// Ensure headers received by client
await responseStartedSyncPoint.WaitToContinue();
await ctx.Response.WriteAsync("STARTED");
// ReadToEndAsync will wait until request body is complete
var requestString = await new StreamReader(ctx.Request.Body).ReadToEndAsync();
await ctx.Response.WriteAsync(requestString + " POST Response");
await requestEndingSyncPoint.WaitToContinue();
};
Stream requestStream = null;
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "http://localhost:12345");
httpRequest.Version = new Version(2, 0);
httpRequest.Content = new PushContent(async stream =>
{
requestStream = stream;
await requestStreamSyncPoint.WaitToContinue();
});
// Act
var response = await client.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).WithTimeout();
await responseStartedSyncPoint.WaitForSyncPoint().WithTimeout();
responseStartedSyncPoint.Continue();
var responseContent = await response.Content.ReadAsStreamAsync().WithTimeout();
// Assert
// Ensure request stream has started
await requestStreamSyncPoint.WaitForSyncPoint();
byte[] buffer = new byte[1024];
var length = await responseContent.ReadAsync(buffer).AsTask().WithTimeout();
Assert.Equal("STARTED", Encoding.UTF8.GetString(buffer, 0, length));
// Send content and finish request body
await requestStream.WriteAsync(Encoding.UTF8.GetBytes("Hello world")).AsTask().WithTimeout();
await requestStream.FlushAsync().WithTimeout();
requestStreamSyncPoint.Continue();
// Ensure content is received while request is in progress
length = await responseContent.ReadAsync(buffer).AsTask().WithTimeout();
Assert.Equal("Hello world POST Response", Encoding.UTF8.GetString(buffer, 0, length));
// Request is ending
await requestEndingSyncPoint.WaitForSyncPoint().WithTimeout();
requestEndingSyncPoint.Continue();
// No more response content
length = await responseContent.ReadAsync(buffer).AsTask().WithTimeout();
Assert.Equal(0, length);
}
[Fact]
public async Task ClientStreaming_Cancellation()
{
// Arrange
var responseStartedSyncPoint = new SyncPoint();
var responseReadSyncPoint = new SyncPoint();
var responseEndingSyncPoint = new SyncPoint();
var requestStreamSyncPoint = new SyncPoint();
var readCanceled = false;
RequestDelegate appDelegate = async ctx =>
{
// Send headers
await ctx.Response.BodyWriter.FlushAsync();
// Ensure headers received by client
await responseStartedSyncPoint.WaitToContinue();
var serverBuffer = new byte[1024];
var serverLength = await ctx.Request.Body.ReadAsync(serverBuffer);
Assert.Equal("SENT", Encoding.UTF8.GetString(serverBuffer, 0, serverLength));
await responseReadSyncPoint.WaitToContinue();
try
{
await ctx.Request.Body.ReadAsync(serverBuffer);
}
catch (OperationCanceledException)
{
readCanceled = true;
}
await responseEndingSyncPoint.WaitToContinue();
};
Stream requestStream = null;
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "http://localhost:12345");
httpRequest.Version = new Version(2, 0);
httpRequest.Content = new PushContent(async stream =>
{
requestStream = stream;
await requestStreamSyncPoint.WaitToContinue();
});
// Act
var response = await client.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).WithTimeout();
await responseStartedSyncPoint.WaitForSyncPoint().WithTimeout();
responseStartedSyncPoint.Continue();
var responseContent = await response.Content.ReadAsStreamAsync().WithTimeout();
// Assert
// Ensure request stream has started
await requestStreamSyncPoint.WaitForSyncPoint();
// Write to request
await requestStream.WriteAsync(Encoding.UTF8.GetBytes("SENT")).AsTask().WithTimeout();
await requestStream.FlushAsync().WithTimeout();
await responseReadSyncPoint.WaitForSyncPoint().WithTimeout();
// Cancel request. Disposing response must be used because SendAsync has finished.
response.Dispose();
responseReadSyncPoint.Continue();
await responseEndingSyncPoint.WaitForSyncPoint().WithTimeout();
responseEndingSyncPoint.Continue();
Assert.True(readCanceled);
requestStreamSyncPoint.Continue();
}
[Fact]
public async Task ClientStreaming_ResponseCompletesWithoutReadingRequest()
{
// Arrange
var requestStreamTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
var responseEndingSyncPoint = new SyncPoint();
RequestDelegate appDelegate = async ctx =>
{
await ctx.Response.WriteAsync("POST Response");
await responseEndingSyncPoint.WaitToContinue();
};
Stream requestStream = null;
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "http://localhost:12345");
httpRequest.Version = new Version(2, 0);
httpRequest.Content = new PushContent(async stream =>
{
requestStream = stream;
await requestStreamTcs.Task;
});
// Act
var response = await client.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).WithTimeout();
var responseContent = await response.Content.ReadAsStreamAsync().WithTimeout();
// Assert
// Read response
byte[] buffer = new byte[1024];
var length = await responseContent.ReadAsync(buffer).AsTask().WithTimeout();
Assert.Equal("POST Response", Encoding.UTF8.GetString(buffer, 0, length));
// Send large content and block on back pressure
var writeTask = Task.Run(async () =>
{
try
{
await requestStream.WriteAsync(Encoding.UTF8.GetBytes(new string('!', 1024 * 1024 * 50))).AsTask().WithTimeout();
requestStreamTcs.SetResult(null);
}
catch (Exception ex)
{
requestStreamTcs.SetException(ex);
}
});
responseEndingSyncPoint.Continue();
// No more response content
length = await responseContent.ReadAsync(buffer).AsTask().WithTimeout();
Assert.Equal(0, length);
await writeTask;
}
[Fact]
public async Task ClientStreaming_ServerAbort()
{
// Arrange
var requestStreamSyncPoint = new SyncPoint();
var responseEndingSyncPoint = new SyncPoint();
RequestDelegate appDelegate = async ctx =>
{
// Send headers
await ctx.Response.BodyWriter.FlushAsync();
ctx.Abort();
await responseEndingSyncPoint.WaitToContinue();
};
Stream requestStream = null;
var builder = new WebHostBuilder().Configure(app => app.Run(appDelegate));
var server = new TestServer(builder);
var client = server.CreateClient();
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "http://localhost:12345");
httpRequest.Version = new Version(2, 0);
httpRequest.Content = new PushContent(async stream =>
{
requestStream = stream;
await requestStreamSyncPoint.WaitToContinue();
});
// Act
var response = await client.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).WithTimeout();
var responseContent = await response.Content.ReadAsStreamAsync().WithTimeout();
// Assert
// Ensure server has aborted
await responseEndingSyncPoint.WaitForSyncPoint();
// Ensure request stream has started
await requestStreamSyncPoint.WaitForSyncPoint();
// Send content and finish request body
await ExceptionAssert.ThrowsAsync<OperationCanceledException>(
() => requestStream.WriteAsync(Encoding.UTF8.GetBytes("Hello world")).AsTask(),
"Flush was canceled on underlying PipeWriter.").WithTimeout();
responseEndingSyncPoint.Continue();
requestStreamSyncPoint.Continue();
}
private class PushContent : HttpContent
{
private readonly Func<Stream, Task> _sendContent;
public PushContent(Func<Stream, Task> sendContent)
{
_sendContent = sendContent;
}
protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
{
return _sendContent(stream);
}
protected override bool TryComputeLength(out long length)
{
length = -1;
return false;
}
}
[Fact]
public async Task WebSocketWorks()
{

View File

@ -3,6 +3,7 @@
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Testing;
namespace Microsoft.AspNetCore.TestHost
{
@ -10,20 +11,8 @@ namespace Microsoft.AspNetCore.TestHost
{
internal static readonly TimeSpan DefaultTimeout = TimeSpan.FromSeconds(15);
internal static Task<T> WithTimeout<T>(this Task<T> task) => task.WithTimeout(DefaultTimeout);
internal static Task<T> WithTimeout<T>(this Task<T> task) => task.TimeoutAfter(DefaultTimeout);
internal static async Task<T> WithTimeout<T>(this Task<T> task, TimeSpan timeout)
{
var completedTask = await Task.WhenAny(task, Task.Delay(timeout));
if (completedTask == task)
{
return await task;
}
else
{
throw new TimeoutException("The task has timed out.");
}
}
internal static Task WithTimeout(this Task task) => task.TimeoutAfter(DefaultTimeout);
}
}