Implement HttpContext.RequestAborted

This commit is contained in:
Master T 2015-09-02 20:36:55 +02:00
parent ee7825ecb8
commit 7dbe5dfbe4
3 changed files with 130 additions and 52 deletions

View File

@ -54,7 +54,7 @@ namespace Microsoft.AspNet.TestHost
[NotNull] HttpRequestMessage request,
CancellationToken cancellationToken)
{
var state = new RequestState(request, _pathBase, cancellationToken);
var state = new RequestState(request, _pathBase);
var requestContent = request.Content ?? new StreamContent(Stream.Null);
var body = await requestContent.ReadAsStreamAsync();
if (body.CanSeek)
@ -63,41 +63,44 @@ namespace Microsoft.AspNet.TestHost
body.Seek(0, SeekOrigin.Begin);
}
state.HttpContext.Request.Body = body;
var registration = cancellationToken.Register(state.Abort);
var registration = cancellationToken.Register(state.AbortRequest);
// Async offload, don't let the test code block the caller.
var offload = Task.Factory.StartNew(async () =>
{
try
{
await _next(state.HttpContext.Features);
state.CompleteResponse();
}
catch (Exception ex)
{
state.Abort(ex);
}
finally
{
registration.Dispose();
state.Dispose();
}
});
try
{
await _next(state.HttpContext.Features);
state.CompleteResponse();
}
catch (Exception ex)
{
state.Abort(ex);
}
finally
{
registration.Dispose();
}
});
return await state.ResponseTask;
}
private class RequestState : IDisposable
private class RequestState
{
private readonly HttpRequestMessage _request;
private TaskCompletionSource<HttpResponseMessage> _responseTcs;
private ResponseStream _responseStream;
private ResponseFeature _responseFeature;
private CancellationTokenSource _requestAbortedSource;
private bool _pipelineFinished;
internal RequestState(HttpRequestMessage request, PathString pathBase, CancellationToken cancellationToken)
internal RequestState(HttpRequestMessage request, PathString pathBase)
{
_request = request;
_responseTcs = new TaskCompletionSource<HttpResponseMessage>();
_requestAbortedSource = new CancellationTokenSource();
_pipelineFinished = false;
if (request.RequestUri.IsDefaultPort)
{
@ -131,7 +134,6 @@ namespace Microsoft.AspNet.TestHost
}
serverRequest.QueryString = QueryString.FromUriComponent(request.RequestUri);
// TODO: serverRequest.CallCancelled = cancellationToken;
foreach (var header in request.Headers)
{
@ -146,9 +148,10 @@ namespace Microsoft.AspNet.TestHost
}
}
_responseStream = new ResponseStream(CompleteResponse);
_responseStream = new ResponseStream(ReturnResponseMessage, AbortRequest);
HttpContext.Response.Body = _responseStream;
HttpContext.Response.StatusCode = 200;
HttpContext.RequestAborted = _requestAbortedSource.Token;
}
public HttpContext HttpContext { get; private set; }
@ -158,7 +161,23 @@ namespace Microsoft.AspNet.TestHost
get { return _responseTcs.Task; }
}
internal void AbortRequest()
{
if (!_pipelineFinished)
{
_requestAbortedSource.Cancel();
}
_responseStream.Complete();
}
internal void CompleteResponse()
{
_pipelineFinished = true;
ReturnResponseMessage();
_responseStream.Complete();
}
internal void ReturnResponseMessage()
{
if (!_responseTcs.Task.IsCompleted)
{
@ -171,7 +190,7 @@ namespace Microsoft.AspNet.TestHost
[SuppressMessage("Microsoft.Reliability", "CA2000:DisposeObjectsBeforeLosingScope",
Justification = "HttpResposneMessage must be returned to the caller.")]
internal HttpResponseMessage GenerateResponse()
private HttpResponseMessage GenerateResponse()
{
_responseFeature.FireOnSendingHeaders();
@ -194,22 +213,12 @@ namespace Microsoft.AspNet.TestHost
return response;
}
internal void Abort()
{
Abort(new OperationCanceledException());
}
internal void Abort(Exception exception)
{
_pipelineFinished = true;
_responseStream.Abort(exception);
_responseTcs.TrySetException(exception);
}
public void Dispose()
{
_responseStream.Dispose();
// Do not dispose the request, that will be disposed by the caller.
}
}
}
}

View File

@ -16,7 +16,7 @@ namespace Microsoft.AspNet.TestHost
// when requested by the client.
internal class ResponseStream : Stream
{
private bool _disposed;
private bool _complete;
private bool _aborted;
private Exception _abortException;
private ConcurrentQueue<byte[]> _bufferedData;
@ -28,11 +28,13 @@ namespace Microsoft.AspNet.TestHost
private Action _onFirstWrite;
private bool _firstWrite;
private Action _abortRequest;
internal ResponseStream([NotNull] Action onFirstWrite)
internal ResponseStream([NotNull] Action onFirstWrite, [NotNull] Action abortRequest)
{
_onFirstWrite = onFirstWrite;
_firstWrite = true;
_abortRequest = abortRequest;
_readLock = new SemaphoreSlim(1, 1);
_writeLock = new SemaphoreSlim(1, 1);
@ -83,7 +85,7 @@ namespace Microsoft.AspNet.TestHost
public override void Flush()
{
CheckDisposed();
CheckNotComplete();
_writeLock.Wait();
try
@ -130,7 +132,7 @@ namespace Microsoft.AspNet.TestHost
byte[] topBuffer = null;
while (!_bufferedData.TryDequeue(out topBuffer))
{
if (_disposed)
if (_complete)
{
CheckAborted();
// Graceful close
@ -189,7 +191,7 @@ namespace Microsoft.AspNet.TestHost
byte[] topBuffer = null;
while (!_bufferedData.TryDequeue(out topBuffer))
{
if (_disposed)
if (_complete)
{
CheckAborted();
// Graceful close
@ -233,7 +235,7 @@ namespace Microsoft.AspNet.TestHost
public override void Write(byte[] buffer, int offset, int count)
{
VerifyBuffer(buffer, offset, count, allowEmpty: true);
CheckDisposed();
CheckNotComplete();
_writeLock.Wait();
try
@ -317,7 +319,7 @@ namespace Microsoft.AspNet.TestHost
{
_readWaitingForData = new TaskCompletionSource<object>();
if (!_bufferedData.IsEmpty || _disposed)
if (!_bufferedData.IsEmpty || _complete)
{
// Race, data could have arrived before we created the TCS.
_readWaitingForData.TrySetResult(null);
@ -337,7 +339,18 @@ namespace Microsoft.AspNet.TestHost
Contract.Requires(innerException != null);
_aborted = true;
_abortException = innerException;
Dispose();
Complete();
}
internal void Complete()
{
// Prevent race with WaitForDataAsync
lock (_signalReadLock)
{
// Throw for further writes, but not reads. Allow reads to drain the buffered data and then return 0 for further reads.
_complete = true;
_readWaitingForData.TrySetResult(null);
}
}
private void CheckAborted()
@ -354,23 +367,16 @@ namespace Microsoft.AspNet.TestHost
{
if (disposing)
{
// Prevent race with WaitForDataAsync
lock (_signalReadLock)
{
// Throw for further writes, but not reads. Allow reads to drain the buffered data and then return 0 for further reads.
_disposed = true;
_readWaitingForData.TrySetResult(null);
}
_abortRequest();
}
base.Dispose(disposing);
}
private void CheckDisposed()
private void CheckNotComplete()
{
if (_disposed)
if (_complete)
{
throw new ObjectDisposedException(GetType().FullName);
throw new IOException("The request was aborted or the pipeline has finished");
}
}
}

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Linq;
using System.Net.Http;
@ -245,5 +246,67 @@ namespace Microsoft.AspNet.TestHost
clientSocket.Dispose();
}
[Fact]
public async Task ClientDisposalAbortsRequest()
{
// Arrange
TaskCompletionSource<object> tcs = new TaskCompletionSource<object>();
RequestDelegate appDelegate = async ctx =>
{
// Write Headers
await ctx.Response.Body.FlushAsync();
var sem = new SemaphoreSlim(0);
try
{
await sem.WaitAsync(ctx.RequestAborted);
}
catch (Exception e)
{
tcs.SetException(e);
}
};
// Act
var server = TestServer.Create(app => app.Run(appDelegate));
var client = server.CreateClient();
var request = new HttpRequestMessage(HttpMethod.Get, "http://localhost:12345");
var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
// Abort Request
response.Dispose();
// Assert
var exception = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await tcs.Task);
}
[Fact]
public async Task ClientCancellationAbortsRequest()
{
// Arrange
TaskCompletionSource<object> tcs = new TaskCompletionSource<object>();
RequestDelegate appDelegate = async ctx =>
{
var sem = new SemaphoreSlim(0);
try
{
await sem.WaitAsync(ctx.RequestAborted);
}
catch (Exception e)
{
tcs.SetException(e);
}
};
// Act
var server = TestServer.Create(app => app.Run(appDelegate));
var client = server.CreateClient();
var cts = new CancellationTokenSource();
cts.CancelAfter(500);
var response = await client.GetAsync("http://localhost:12345", cts.Token);
// Assert
var exception = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await tcs.Task);
}
}
}