diff --git a/src/Microsoft.AspNetCore.TestHost/ClientHandler.cs b/src/Microsoft.AspNetCore.TestHost/ClientHandler.cs index af01a88619..87454d7530 100644 --- a/src/Microsoft.AspNetCore.TestHost/ClientHandler.cs +++ b/src/Microsoft.AspNetCore.TestHost/ClientHandler.cs @@ -81,7 +81,7 @@ namespace Microsoft.AspNetCore.TestHost try { await _application.ProcessRequestAsync(state.Context); - state.CompleteResponse(); + await state.CompleteResponseAsync(); state.ServerCleanup(exception: null); } catch (Exception ex) @@ -165,7 +165,7 @@ namespace Microsoft.AspNetCore.TestHost } } - _responseStream = new ResponseStream(ReturnResponseMessage, AbortRequest); + _responseStream = new ResponseStream(ReturnResponseMessageAsync, AbortRequest); httpContext.Response.Body = _responseStream; httpContext.Response.StatusCode = 200; httpContext.RequestAborted = _requestAbortedSource.Token; @@ -187,27 +187,30 @@ namespace Microsoft.AspNetCore.TestHost _responseStream.Complete(); } - internal void CompleteResponse() + internal async Task CompleteResponseAsync() { _pipelineFinished = true; - ReturnResponseMessage(); + await ReturnResponseMessageAsync(); _responseStream.Complete(); - _responseFeature.FireOnResponseCompleted(); + await _responseFeature.FireOnResponseCompletedAsync(); } - internal void ReturnResponseMessage() + internal async Task ReturnResponseMessageAsync() { - if (!_responseTcs.Task.IsCompleted) + // Check if the response has already started because the TrySetResult below could happen a bit late + // (as it happens on a different thread) by which point the CompleteResponseAsync could run and calls this + // method again. + if (!Context.HttpContext.Response.HasStarted) { - var response = GenerateResponse(); + var response = await GenerateResponseAsync(); // Dispatch, as TrySetResult will synchronously execute the waiters callback and block our Write. - Task.Factory.StartNew(() => _responseTcs.TrySetResult(response)); + var setResult = Task.Factory.StartNew(() => _responseTcs.TrySetResult(response)); } } - private HttpResponseMessage GenerateResponse() + private async Task GenerateResponseAsync() { - _responseFeature.FireOnSendingHeaders(); + await _responseFeature.FireOnSendingHeadersAsync(); var httpContext = Context.HttpContext; var response = new HttpResponseMessage(); diff --git a/src/Microsoft.AspNetCore.TestHost/ResponseFeature.cs b/src/Microsoft.AspNetCore.TestHost/ResponseFeature.cs index 873da77a21..73a17f6d33 100644 --- a/src/Microsoft.AspNetCore.TestHost/ResponseFeature.cs +++ b/src/Microsoft.AspNetCore.TestHost/ResponseFeature.cs @@ -11,8 +11,8 @@ namespace Microsoft.AspNetCore.TestHost { internal class ResponseFeature : IHttpResponseFeature { - private Action _responseStarting = () => { }; - private Action _responseCompleted = () => { }; + private Func _responseStartingAsync = () => Task.FromResult(true); + private Func _responseCompletedAsync = () => Task.FromResult(true); public ResponseFeature() { @@ -36,33 +36,39 @@ namespace Microsoft.AspNetCore.TestHost public void OnStarting(Func callback, object state) { - var prior = _responseStarting; - _responseStarting = () => + var prior = _responseStartingAsync; + _responseStartingAsync = async () => { - callback(state); - prior(); + await callback(state); + await prior(); }; } public void OnCompleted(Func callback, object state) { - var prior = _responseCompleted; - _responseCompleted = () => + var prior = _responseCompletedAsync; + _responseCompletedAsync = async () => { - callback(state); - prior(); + try + { + await callback(state); + } + finally + { + await prior(); + } }; } - public void FireOnSendingHeaders() + public async Task FireOnSendingHeadersAsync() { - _responseStarting(); + await _responseStartingAsync(); HasStarted = true; } - public void FireOnResponseCompleted() + public Task FireOnResponseCompletedAsync() { - _responseCompleted(); + return _responseCompletedAsync(); } } } diff --git a/src/Microsoft.AspNetCore.TestHost/ResponseStream.cs b/src/Microsoft.AspNetCore.TestHost/ResponseStream.cs index 85ee89e3bb..df3a02fcfc 100644 --- a/src/Microsoft.AspNetCore.TestHost/ResponseStream.cs +++ b/src/Microsoft.AspNetCore.TestHost/ResponseStream.cs @@ -24,15 +24,15 @@ namespace Microsoft.AspNetCore.TestHost private TaskCompletionSource _readWaitingForData; private object _signalReadLock; - private Action _onFirstWrite; + private Func _onFirstWriteAsync; private bool _firstWrite; private Action _abortRequest; - internal ResponseStream(Action onFirstWrite, Action abortRequest) + internal ResponseStream(Func onFirstWriteAsync, Action abortRequest) { - if (onFirstWrite == null) + if (onFirstWriteAsync == null) { - throw new ArgumentNullException(nameof(onFirstWrite)); + throw new ArgumentNullException(nameof(onFirstWriteAsync)); } if (abortRequest == null) @@ -40,7 +40,7 @@ namespace Microsoft.AspNetCore.TestHost throw new ArgumentNullException(nameof(abortRequest)); } - _onFirstWrite = onFirstWrite; + _onFirstWriteAsync = onFirstWriteAsync; _firstWrite = true; _abortRequest = abortRequest; @@ -98,7 +98,7 @@ namespace Microsoft.AspNetCore.TestHost _writeLock.Wait(); try { - FirstWrite(); + FirstWriteAsync().GetAwaiter().GetResult(); } finally { @@ -230,13 +230,14 @@ namespace Microsoft.AspNetCore.TestHost } // Called under write-lock. - private void FirstWrite() + private Task FirstWriteAsync() { if (_firstWrite) { _firstWrite = false; - _onFirstWrite(); + return _onFirstWriteAsync(); } + return Task.FromResult(true); } // Write with count 0 will still trigger OnFirstWrite @@ -248,7 +249,7 @@ namespace Microsoft.AspNetCore.TestHost _writeLock.Wait(); try { - FirstWrite(); + FirstWriteAsync().GetAwaiter().GetResult(); if (count == 0) { return; diff --git a/test/Microsoft.AspNetCore.TestHost.Tests/ResponseFeatureTests.cs b/test/Microsoft.AspNetCore.TestHost.Tests/ResponseFeatureTests.cs index dc8831fc99..9ab936bb81 100644 --- a/test/Microsoft.AspNetCore.TestHost.Tests/ResponseFeatureTests.cs +++ b/test/Microsoft.AspNetCore.TestHost.Tests/ResponseFeatureTests.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Threading.Tasks; using Xunit; namespace Microsoft.AspNetCore.TestHost @@ -8,7 +9,7 @@ namespace Microsoft.AspNetCore.TestHost public class ResponseFeatureTests { [Fact] - public void StatusCode_DefaultsTo200() + public async Task StatusCode_DefaultsTo200() { // Arrange & Act var responseInformation = new ResponseFeature(); @@ -17,7 +18,7 @@ namespace Microsoft.AspNetCore.TestHost Assert.Equal(200, responseInformation.StatusCode); Assert.False(responseInformation.HasStarted); - responseInformation.FireOnSendingHeaders(); + await responseInformation.FireOnSendingHeadersAsync(); Assert.True(responseInformation.HasStarted); }