diff --git a/src/Microsoft.AspNet.TestHost/ClientHandler.cs b/src/Microsoft.AspNet.TestHost/ClientHandler.cs index 970c232afe..12e03e8980 100644 --- a/src/Microsoft.AspNet.TestHost/ClientHandler.cs +++ b/src/Microsoft.AspNet.TestHost/ClientHandler.cs @@ -54,7 +54,7 @@ namespace Microsoft.AspNet.TestHost [NotNull] HttpRequestMessage request, CancellationToken cancellationToken) { - var state = new RequestState(request, _pathBase, cancellationToken); + var state = new RequestState(request, _pathBase); var requestContent = request.Content ?? new StreamContent(Stream.Null); var body = await requestContent.ReadAsStreamAsync(); if (body.CanSeek) @@ -63,41 +63,44 @@ namespace Microsoft.AspNet.TestHost body.Seek(0, SeekOrigin.Begin); } state.HttpContext.Request.Body = body; - var registration = cancellationToken.Register(state.Abort); + var registration = cancellationToken.Register(state.AbortRequest); // Async offload, don't let the test code block the caller. var offload = Task.Factory.StartNew(async () => - { - try { - await _next(state.HttpContext.Features); - state.CompleteResponse(); - } - catch (Exception ex) - { - state.Abort(ex); - } - finally - { - registration.Dispose(); - state.Dispose(); - } - }); + try + { + await _next(state.HttpContext.Features); + state.CompleteResponse(); + } + catch (Exception ex) + { + state.Abort(ex); + } + finally + { + registration.Dispose(); + } + }); return await state.ResponseTask; } - private class RequestState : IDisposable + private class RequestState { private readonly HttpRequestMessage _request; private TaskCompletionSource _responseTcs; private ResponseStream _responseStream; private ResponseFeature _responseFeature; + private CancellationTokenSource _requestAbortedSource; + private bool _pipelineFinished; - internal RequestState(HttpRequestMessage request, PathString pathBase, CancellationToken cancellationToken) + internal RequestState(HttpRequestMessage request, PathString pathBase) { _request = request; _responseTcs = new TaskCompletionSource(); + _requestAbortedSource = new CancellationTokenSource(); + _pipelineFinished = false; if (request.RequestUri.IsDefaultPort) { @@ -131,7 +134,6 @@ namespace Microsoft.AspNet.TestHost } serverRequest.QueryString = QueryString.FromUriComponent(request.RequestUri); - // TODO: serverRequest.CallCancelled = cancellationToken; foreach (var header in request.Headers) { @@ -146,9 +148,10 @@ namespace Microsoft.AspNet.TestHost } } - _responseStream = new ResponseStream(CompleteResponse); + _responseStream = new ResponseStream(ReturnResponseMessage, AbortRequest); HttpContext.Response.Body = _responseStream; HttpContext.Response.StatusCode = 200; + HttpContext.RequestAborted = _requestAbortedSource.Token; } public HttpContext HttpContext { get; private set; } @@ -158,7 +161,23 @@ namespace Microsoft.AspNet.TestHost get { return _responseTcs.Task; } } + internal void AbortRequest() + { + if (!_pipelineFinished) + { + _requestAbortedSource.Cancel(); + } + _responseStream.Complete(); + } + internal void CompleteResponse() + { + _pipelineFinished = true; + ReturnResponseMessage(); + _responseStream.Complete(); + } + + internal void ReturnResponseMessage() { if (!_responseTcs.Task.IsCompleted) { @@ -171,7 +190,7 @@ namespace Microsoft.AspNet.TestHost [SuppressMessage("Microsoft.Reliability", "CA2000:DisposeObjectsBeforeLosingScope", Justification = "HttpResposneMessage must be returned to the caller.")] - internal HttpResponseMessage GenerateResponse() + private HttpResponseMessage GenerateResponse() { _responseFeature.FireOnSendingHeaders(); @@ -194,22 +213,12 @@ namespace Microsoft.AspNet.TestHost return response; } - internal void Abort() - { - Abort(new OperationCanceledException()); - } - internal void Abort(Exception exception) { + _pipelineFinished = true; _responseStream.Abort(exception); _responseTcs.TrySetException(exception); } - - public void Dispose() - { - _responseStream.Dispose(); - // Do not dispose the request, that will be disposed by the caller. - } } } } diff --git a/src/Microsoft.AspNet.TestHost/ResponseStream.cs b/src/Microsoft.AspNet.TestHost/ResponseStream.cs index 42a8a419c5..7f75c5983f 100644 --- a/src/Microsoft.AspNet.TestHost/ResponseStream.cs +++ b/src/Microsoft.AspNet.TestHost/ResponseStream.cs @@ -16,7 +16,7 @@ namespace Microsoft.AspNet.TestHost // when requested by the client. internal class ResponseStream : Stream { - private bool _disposed; + private bool _complete; private bool _aborted; private Exception _abortException; private ConcurrentQueue _bufferedData; @@ -28,11 +28,13 @@ namespace Microsoft.AspNet.TestHost private Action _onFirstWrite; private bool _firstWrite; + private Action _abortRequest; - internal ResponseStream([NotNull] Action onFirstWrite) + internal ResponseStream([NotNull] Action onFirstWrite, [NotNull] Action abortRequest) { _onFirstWrite = onFirstWrite; _firstWrite = true; + _abortRequest = abortRequest; _readLock = new SemaphoreSlim(1, 1); _writeLock = new SemaphoreSlim(1, 1); @@ -83,7 +85,7 @@ namespace Microsoft.AspNet.TestHost public override void Flush() { - CheckDisposed(); + CheckNotComplete(); _writeLock.Wait(); try @@ -130,7 +132,7 @@ namespace Microsoft.AspNet.TestHost byte[] topBuffer = null; while (!_bufferedData.TryDequeue(out topBuffer)) { - if (_disposed) + if (_complete) { CheckAborted(); // Graceful close @@ -189,7 +191,7 @@ namespace Microsoft.AspNet.TestHost byte[] topBuffer = null; while (!_bufferedData.TryDequeue(out topBuffer)) { - if (_disposed) + if (_complete) { CheckAborted(); // Graceful close @@ -233,7 +235,7 @@ namespace Microsoft.AspNet.TestHost public override void Write(byte[] buffer, int offset, int count) { VerifyBuffer(buffer, offset, count, allowEmpty: true); - CheckDisposed(); + CheckNotComplete(); _writeLock.Wait(); try @@ -317,7 +319,7 @@ namespace Microsoft.AspNet.TestHost { _readWaitingForData = new TaskCompletionSource(); - if (!_bufferedData.IsEmpty || _disposed) + if (!_bufferedData.IsEmpty || _complete) { // Race, data could have arrived before we created the TCS. _readWaitingForData.TrySetResult(null); @@ -337,7 +339,18 @@ namespace Microsoft.AspNet.TestHost Contract.Requires(innerException != null); _aborted = true; _abortException = innerException; - Dispose(); + Complete(); + } + + internal void Complete() + { + // Prevent race with WaitForDataAsync + lock (_signalReadLock) + { + // Throw for further writes, but not reads. Allow reads to drain the buffered data and then return 0 for further reads. + _complete = true; + _readWaitingForData.TrySetResult(null); + } } private void CheckAborted() @@ -354,23 +367,16 @@ namespace Microsoft.AspNet.TestHost { if (disposing) { - // Prevent race with WaitForDataAsync - lock (_signalReadLock) - { - // Throw for further writes, but not reads. Allow reads to drain the buffered data and then return 0 for further reads. - _disposed = true; - _readWaitingForData.TrySetResult(null); - } + _abortRequest(); } - base.Dispose(disposing); } - private void CheckDisposed() + private void CheckNotComplete() { - if (_disposed) + if (_complete) { - throw new ObjectDisposedException(GetType().FullName); + throw new IOException("The request was aborted or the pipeline has finished"); } } } diff --git a/test/Microsoft.AspNet.TestHost.Tests/TestClientTests.cs b/test/Microsoft.AspNet.TestHost.Tests/TestClientTests.cs index e207d499fc..ae1dfd19ca 100644 --- a/test/Microsoft.AspNet.TestHost.Tests/TestClientTests.cs +++ b/test/Microsoft.AspNet.TestHost.Tests/TestClientTests.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.IO; using System.Linq; using System.Net.Http; @@ -245,5 +246,67 @@ namespace Microsoft.AspNet.TestHost clientSocket.Dispose(); } + + [Fact] + public async Task ClientDisposalAbortsRequest() + { + // Arrange + TaskCompletionSource tcs = new TaskCompletionSource(); + RequestDelegate appDelegate = async ctx => + { + // Write Headers + await ctx.Response.Body.FlushAsync(); + + var sem = new SemaphoreSlim(0); + try + { + await sem.WaitAsync(ctx.RequestAborted); + } + catch (Exception e) + { + tcs.SetException(e); + } + }; + + // Act + var server = TestServer.Create(app => app.Run(appDelegate)); + var client = server.CreateClient(); + var request = new HttpRequestMessage(HttpMethod.Get, "http://localhost:12345"); + var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); + // Abort Request + response.Dispose(); + + // Assert + var exception = await Assert.ThrowsAnyAsync(async () => await tcs.Task); + } + + [Fact] + public async Task ClientCancellationAbortsRequest() + { + // Arrange + TaskCompletionSource tcs = new TaskCompletionSource(); + RequestDelegate appDelegate = async ctx => + { + var sem = new SemaphoreSlim(0); + try + { + await sem.WaitAsync(ctx.RequestAborted); + } + catch (Exception e) + { + tcs.SetException(e); + } + }; + + // Act + var server = TestServer.Create(app => app.Run(appDelegate)); + var client = server.CreateClient(); + var cts = new CancellationTokenSource(); + cts.CancelAfter(500); + var response = await client.GetAsync("http://localhost:12345", cts.Token); + + // Assert + var exception = await Assert.ThrowsAnyAsync(async () => await tcs.Task); + } } }