[Fixes #852] TestHost: OnStarting and OnCompleted callbacks of response are not being awaited

This commit is contained in:
Kiran Challa 2016-09-12 10:59:33 -07:00
parent 98e35cc6da
commit b6da89f54c
4 changed files with 47 additions and 36 deletions

View File

@ -81,7 +81,7 @@ namespace Microsoft.AspNetCore.TestHost
try
{
await _application.ProcessRequestAsync(state.Context);
state.CompleteResponse();
await state.CompleteResponseAsync();
state.ServerCleanup(exception: null);
}
catch (Exception ex)
@ -165,7 +165,7 @@ namespace Microsoft.AspNetCore.TestHost
}
}
_responseStream = new ResponseStream(ReturnResponseMessage, AbortRequest);
_responseStream = new ResponseStream(ReturnResponseMessageAsync, AbortRequest);
httpContext.Response.Body = _responseStream;
httpContext.Response.StatusCode = 200;
httpContext.RequestAborted = _requestAbortedSource.Token;
@ -187,27 +187,30 @@ namespace Microsoft.AspNetCore.TestHost
_responseStream.Complete();
}
internal void CompleteResponse()
internal async Task CompleteResponseAsync()
{
_pipelineFinished = true;
ReturnResponseMessage();
await ReturnResponseMessageAsync();
_responseStream.Complete();
_responseFeature.FireOnResponseCompleted();
await _responseFeature.FireOnResponseCompletedAsync();
}
internal void ReturnResponseMessage()
internal async Task ReturnResponseMessageAsync()
{
if (!_responseTcs.Task.IsCompleted)
// Check if the response has already started because the TrySetResult below could happen a bit late
// (as it happens on a different thread) by which point the CompleteResponseAsync could run and calls this
// method again.
if (!Context.HttpContext.Response.HasStarted)
{
var response = GenerateResponse();
var response = await GenerateResponseAsync();
// Dispatch, as TrySetResult will synchronously execute the waiters callback and block our Write.
Task.Factory.StartNew(() => _responseTcs.TrySetResult(response));
var setResult = Task.Factory.StartNew(() => _responseTcs.TrySetResult(response));
}
}
private HttpResponseMessage GenerateResponse()
private async Task<HttpResponseMessage> GenerateResponseAsync()
{
_responseFeature.FireOnSendingHeaders();
await _responseFeature.FireOnSendingHeadersAsync();
var httpContext = Context.HttpContext;
var response = new HttpResponseMessage();

View File

@ -11,8 +11,8 @@ namespace Microsoft.AspNetCore.TestHost
{
internal class ResponseFeature : IHttpResponseFeature
{
private Action _responseStarting = () => { };
private Action _responseCompleted = () => { };
private Func<Task> _responseStartingAsync = () => Task.FromResult(true);
private Func<Task> _responseCompletedAsync = () => Task.FromResult(true);
public ResponseFeature()
{
@ -36,33 +36,39 @@ namespace Microsoft.AspNetCore.TestHost
public void OnStarting(Func<object, Task> callback, object state)
{
var prior = _responseStarting;
_responseStarting = () =>
var prior = _responseStartingAsync;
_responseStartingAsync = async () =>
{
callback(state);
prior();
await callback(state);
await prior();
};
}
public void OnCompleted(Func<object, Task> callback, object state)
{
var prior = _responseCompleted;
_responseCompleted = () =>
var prior = _responseCompletedAsync;
_responseCompletedAsync = async () =>
{
callback(state);
prior();
try
{
await callback(state);
}
finally
{
await prior();
}
};
}
public void FireOnSendingHeaders()
public async Task FireOnSendingHeadersAsync()
{
_responseStarting();
await _responseStartingAsync();
HasStarted = true;
}
public void FireOnResponseCompleted()
public Task FireOnResponseCompletedAsync()
{
_responseCompleted();
return _responseCompletedAsync();
}
}
}

View File

@ -24,15 +24,15 @@ namespace Microsoft.AspNetCore.TestHost
private TaskCompletionSource<object> _readWaitingForData;
private object _signalReadLock;
private Action _onFirstWrite;
private Func<Task> _onFirstWriteAsync;
private bool _firstWrite;
private Action _abortRequest;
internal ResponseStream(Action onFirstWrite, Action abortRequest)
internal ResponseStream(Func<Task> onFirstWriteAsync, Action abortRequest)
{
if (onFirstWrite == null)
if (onFirstWriteAsync == null)
{
throw new ArgumentNullException(nameof(onFirstWrite));
throw new ArgumentNullException(nameof(onFirstWriteAsync));
}
if (abortRequest == null)
@ -40,7 +40,7 @@ namespace Microsoft.AspNetCore.TestHost
throw new ArgumentNullException(nameof(abortRequest));
}
_onFirstWrite = onFirstWrite;
_onFirstWriteAsync = onFirstWriteAsync;
_firstWrite = true;
_abortRequest = abortRequest;
@ -98,7 +98,7 @@ namespace Microsoft.AspNetCore.TestHost
_writeLock.Wait();
try
{
FirstWrite();
FirstWriteAsync().GetAwaiter().GetResult();
}
finally
{
@ -230,13 +230,14 @@ namespace Microsoft.AspNetCore.TestHost
}
// Called under write-lock.
private void FirstWrite()
private Task FirstWriteAsync()
{
if (_firstWrite)
{
_firstWrite = false;
_onFirstWrite();
return _onFirstWriteAsync();
}
return Task.FromResult(true);
}
// Write with count 0 will still trigger OnFirstWrite
@ -248,7 +249,7 @@ namespace Microsoft.AspNetCore.TestHost
_writeLock.Wait();
try
{
FirstWrite();
FirstWriteAsync().GetAwaiter().GetResult();
if (count == 0)
{
return;

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.AspNetCore.TestHost
@ -8,7 +9,7 @@ namespace Microsoft.AspNetCore.TestHost
public class ResponseFeatureTests
{
[Fact]
public void StatusCode_DefaultsTo200()
public async Task StatusCode_DefaultsTo200()
{
// Arrange & Act
var responseInformation = new ResponseFeature();
@ -17,7 +18,7 @@ namespace Microsoft.AspNetCore.TestHost
Assert.Equal(200, responseInformation.StatusCode);
Assert.False(responseInformation.HasStarted);
responseInformation.FireOnSendingHeaders();
await responseInformation.FireOnSendingHeadersAsync();
Assert.True(responseInformation.HasStarted);
}