#43 - Honor CancellationTokens for Read/Write/Flush/SendFileAsync.

This commit is contained in:
Chris Ross 2014-08-08 14:57:21 -07:00
parent dbf13614da
commit 5de5534982
7 changed files with 323 additions and 17 deletions

View File

@ -33,6 +33,13 @@ namespace Microsoft.Net.Http.Server
return Task.FromResult<object>(null);
}
internal static Task<T> CancelledTask<T>()
{
TaskCompletionSource<T> tcs = new TaskCompletionSource<T>();
tcs.TrySetCanceled();
return tcs.Task;
}
internal static ConfiguredTaskAwaitable SupressContext(this Task task)
{
return task.ConfigureAwait(continueOnCapturedContext: false);

View File

@ -37,6 +37,8 @@ namespace Microsoft.Net.Http.Server
{
public sealed class RequestContext : IDisposable
{
internal static Action<object> AbortDelegate = Abort;
private WebListener _server;
private Request _request;
private Response _response;
@ -403,6 +405,12 @@ namespace Microsoft.Net.Http.Server
_request.Dispose();
}
private static void Abort(object state)
{
var context = (RequestContext)state;
context.Abort();
}
// This is only called while processing incoming requests. We don't have to worry about cancelling
// any response writes.
[SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification =

View File

@ -319,8 +319,10 @@ namespace Microsoft.Net.Http.Server
return Task.FromResult<int>(0);
}
// TODO: Needs full cancellation integration
cancellationToken.ThrowIfCancellationRequested();
if (cancellationToken.IsCancellationRequested)
{
return Helpers.CancelledTask<int>();
}
// TODO: Verbose log parameters
RequestStreamAsyncResult asyncResult = null;
@ -349,7 +351,13 @@ namespace Microsoft.Net.Http.Server
size = MaxReadSize;
}
asyncResult = new RequestStreamAsyncResult(this, null, null, buffer, offset, dataRead);
CancellationTokenRegistration cancellationRegistration;
if (cancellationToken.CanBeCanceled)
{
cancellationRegistration = cancellationToken.Register(RequestContext.AbortDelegate, _requestContext);
}
asyncResult = new RequestStreamAsyncResult(this, null, null, buffer, offset, dataRead, cancellationRegistration);
uint bytesReturned;
try
@ -449,6 +457,7 @@ namespace Microsoft.Net.Http.Server
private TaskCompletionSource<int> _tcs;
private RequestStream _requestStream;
private AsyncCallback _callback;
private CancellationTokenRegistration _cancellationRegistration;
internal RequestStreamAsyncResult(RequestStream requestStream, object userState, AsyncCallback callback)
{
@ -464,6 +473,11 @@ namespace Microsoft.Net.Http.Server
}
internal RequestStreamAsyncResult(RequestStream requestStream, object userState, AsyncCallback callback, byte[] buffer, int offset, uint dataAlreadyRead)
: this(requestStream, userState, callback, buffer, offset, dataAlreadyRead, new CancellationTokenRegistration())
{
}
internal RequestStreamAsyncResult(RequestStream requestStream, object userState, AsyncCallback callback, byte[] buffer, int offset, uint dataAlreadyRead, CancellationTokenRegistration cancellationRegistration)
: this(requestStream, userState, callback)
{
_dataAlreadyRead = dataAlreadyRead;
@ -471,6 +485,7 @@ namespace Microsoft.Net.Http.Server
overlapped.AsyncResult = this;
_overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, buffer));
_pinnedBuffer = (Marshal.UnsafeAddrOfPinnedArrayElement(buffer, offset));
_cancellationRegistration = cancellationRegistration;
}
internal RequestStream RequestStream
@ -583,6 +598,7 @@ namespace Microsoft.Net.Http.Server
{
_overlapped.Dispose();
}
_cancellationRegistration.Dispose();
}
}

View File

@ -135,12 +135,20 @@ namespace Microsoft.Net.Http.Server
UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite();
// TODO: Verbose log
// TODO: Real cancellation
cancellationToken.ThrowIfCancellationRequested();
if (cancellationToken.IsCancellationRequested)
{
return Helpers.CancelledTask<int>();
}
CancellationTokenRegistration cancellationRegistration;
if (cancellationToken.CanBeCanceled)
{
cancellationRegistration = cancellationToken.Register(RequestContext.AbortDelegate, _requestContext);
}
// TODO: Don't add MoreData flag if content-length == 0?
flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, null, 0, 0, _requestContext.Response.BoundaryType == BoundaryType.Chunked, false);
ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, null, 0, 0, _requestContext.Response.BoundaryType == BoundaryType.Chunked, false, cancellationRegistration);
try
{
@ -492,7 +500,7 @@ namespace Microsoft.Net.Http.Server
}
}
public override unsafe Task WriteAsync(byte[] buffer, int offset, int size, CancellationToken cancel)
public override unsafe Task WriteAsync(byte[] buffer, int offset, int size, CancellationToken cancellationToken)
{
if (buffer == null)
{
@ -521,14 +529,22 @@ namespace Microsoft.Net.Http.Server
}
// TODO: Verbose log
// TODO: Real cancelation
cancel.ThrowIfCancellationRequested();
if (cancellationToken.IsCancellationRequested)
{
return Helpers.CancelledTask<int>();
}
CancellationTokenRegistration cancellationRegistration;
if (cancellationToken.CanBeCanceled)
{
cancellationRegistration = cancellationToken.Register(RequestContext.AbortDelegate, _requestContext);
}
uint statusCode;
uint bytesSent = 0;
flags |= _leftToWrite == size ? UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE : UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
bool sentHeaders = _requestContext.Response.SentHeaders;
ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, buffer, offset, size, _requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders);
ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, buffer, offset, size, _requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders, cancellationRegistration);
// Update m_LeftToWrite now so we can queue up additional BeginWrite's without waiting for EndWrite.
UpdateWritenCount((uint)((_requestContext.Response.BoundaryType == BoundaryType.Chunked) ? 0 : size));
@ -597,7 +613,7 @@ namespace Microsoft.Net.Http.Server
return asyncResult.Task;
}
internal unsafe Task SendFileAsync(string fileName, long offset, long? size, CancellationToken cancel)
internal unsafe Task SendFileAsync(string fileName, long offset, long? size, CancellationToken cancellationToken)
{
// It's too expensive to validate the file attributes before opening the file. Open the file and then check the lengths.
// This all happens inside of ResponseStreamAsyncResult.
@ -610,9 +626,6 @@ namespace Microsoft.Net.Http.Server
throw new ObjectDisposedException(GetType().FullName);
}
// TODO: Real cancellation
cancel.ThrowIfCancellationRequested();
UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite();
if (size == 0 && _leftToWrite != 0)
{
@ -624,12 +637,23 @@ namespace Microsoft.Net.Http.Server
}
// TODO: Verbose log
if (cancellationToken.IsCancellationRequested)
{
return Helpers.CancelledTask<int>();
}
CancellationTokenRegistration cancellationRegistration;
if (cancellationToken.CanBeCanceled)
{
cancellationRegistration = cancellationToken.Register(RequestContext.AbortDelegate, _requestContext);
}
uint statusCode;
uint bytesSent = 0;
flags |= _leftToWrite == size ? UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE : UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
bool sentHeaders = _requestContext.Response.SentHeaders;
ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, fileName, offset, size,
_requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders);
_requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders, cancellationRegistration);
long bytesWritten;
if (_requestContext.Response.BoundaryType == BoundaryType.Chunked)

View File

@ -43,6 +43,7 @@ namespace Microsoft.Net.Http.Server
private TaskCompletionSource<object> _tcs;
private AsyncCallback _callback;
private uint _bytesSent;
private CancellationTokenRegistration _cancellationRegistration;
internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback)
{
@ -50,12 +51,20 @@ namespace Microsoft.Net.Http.Server
_tcs = new TaskCompletionSource<object>(userState);
_callback = callback;
}
internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback,
byte[] buffer, int offset, int size, bool chunked, bool sentHeaders)
: this(responseStream, userState, callback, buffer, offset, size, chunked, sentHeaders,
new CancellationTokenRegistration())
{
}
internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback,
byte[] buffer, int offset, int size, bool chunked, bool sentHeaders,
CancellationTokenRegistration cancellationRegistration)
: this(responseStream, userState, callback)
{
_sentHeaders = sentHeaders;
_cancellationRegistration = cancellationRegistration;
Overlapped overlapped = new Overlapped();
overlapped.AsyncResult = this;
@ -121,10 +130,12 @@ namespace Microsoft.Net.Http.Server
}
internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback,
string fileName, long offset, long? size, bool chunked, bool sentHeaders)
string fileName, long offset, long? size, bool chunked, bool sentHeaders,
CancellationTokenRegistration cancellationRegistration)
: this(responseStream, userState, callback)
{
_sentHeaders = sentHeaders;
_cancellationRegistration = cancellationRegistration;
Overlapped overlapped = new Overlapped();
overlapped.AsyncResult = this;
@ -449,6 +460,7 @@ namespace Microsoft.Net.Http.Server
{
_fileStream.Dispose();
}
_cancellationRegistration.Dispose();
}
}
}

View File

@ -168,6 +168,156 @@ namespace Microsoft.Net.Http.Server
}
}
[Fact]
public async Task RequestBody_ReadAsyncAlreadyCancelled_ReturnsCanceledTask()
{
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
Task<string> responseTask = SendRequestAsync(address, "Hello World");
var context = await server.GetContextAsync();
byte[] input = new byte[10];
var cts = new CancellationTokenSource();
cts.Cancel();
Task<int> task = context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token);
Assert.True(task.IsCanceled);
context.Dispose();
string response = await responseTask;
Assert.Equal(string.Empty, response);
}
}
[Fact]
public async Task RequestBody_ReadAsyncPartialBodyWithCancellationToken_Success()
{
StaggardContent content = new StaggardContent();
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
Task<string> responseTask = SendRequestAsync(address, content);
var context = await server.GetContextAsync();
byte[] input = new byte[10];
var cts = new CancellationTokenSource();
int read = await context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token);
Assert.Equal(5, read);
content.Block.Release();
read = await context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token);
Assert.Equal(5, read);
context.Dispose();
string response = await responseTask;
Assert.Equal(string.Empty, response);
}
}
[Fact]
public async Task RequestBody_ReadAsyncPartialBodyWithTimeout_Success()
{
StaggardContent content = new StaggardContent();
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
Task<string> responseTask = SendRequestAsync(address, content);
var context = await server.GetContextAsync();
byte[] input = new byte[10];
var cts = new CancellationTokenSource();
cts.CancelAfter(TimeSpan.FromSeconds(5));
int read = await context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token);
Assert.Equal(5, read);
content.Block.Release();
read = await context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token);
Assert.Equal(5, read);
context.Dispose();
string response = await responseTask;
Assert.Equal(string.Empty, response);
}
}
[Fact]
public async Task RequestBody_ReadAsyncPartialBodyAndCancel_Canceled()
{
StaggardContent content = new StaggardContent();
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
Task<string> responseTask = SendRequestAsync(address, content);
var context = await server.GetContextAsync();
byte[] input = new byte[10];
var cts = new CancellationTokenSource();
int read = await context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token);
Assert.Equal(5, read);
var readTask = context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token);
Assert.False(readTask.IsCanceled);
cts.Cancel();
await Assert.ThrowsAsync<WebListenerException>(async () => await readTask);
content.Block.Release();
context.Dispose();
await Assert.ThrowsAsync<HttpRequestException>(async () => await responseTask);
}
}
[Fact]
public async Task RequestBody_ReadAsyncPartialBodyAndExpiredTimeout_Canceled()
{
StaggardContent content = new StaggardContent();
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
Task<string> responseTask = SendRequestAsync(address, content);
var context = await server.GetContextAsync();
byte[] input = new byte[10];
var cts = new CancellationTokenSource();
int read = await context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token);
Assert.Equal(5, read);
cts.CancelAfter(TimeSpan.FromMilliseconds(100));
var readTask = context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token);
Assert.False(readTask.IsCanceled);
await Assert.ThrowsAsync<WebListenerException>(async () => await readTask);
content.Block.Release();
context.Dispose();
await Assert.ThrowsAsync<HttpRequestException>(async () => await responseTask);
}
}
// Make sure that using our own disconnect token as a read cancellation token doesn't
// cause recursion problems when it fires and calls Abort.
[Fact]
public async Task RequestBody_ReadAsyncPartialBodyAndDisconnectedClient_Canceled()
{
StaggardContent content = new StaggardContent();
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
var client = new HttpClient();
var responseTask = client.PostAsync(address, content);
var context = await server.GetContextAsync();
byte[] input = new byte[10];
int read = await context.Request.Body.ReadAsync(input, 0, input.Length, context.DisconnectToken);
Assert.False(context.DisconnectToken.IsCancellationRequested);
// The client should timeout and disconnect, making this read fail.
var assertTask = Assert.ThrowsAsync<WebListenerException>(async () => await context.Request.Body.ReadAsync(input, 0, input.Length, context.DisconnectToken));
client.CancelPendingRequests();
await assertTask;
content.Block.Release();
context.Dispose();
await Assert.ThrowsAsync<TaskCanceledException>(async () => await responseTask);
}
}
private Task<string> SendRequestAsync(string uri, string upload)
{
return SendRequestAsync(uri, new StringContent(upload));
@ -177,6 +327,7 @@ namespace Microsoft.Net.Http.Server
{
using (HttpClient client = new HttpClient())
{
client.Timeout = TimeSpan.FromSeconds(10);
HttpResponseMessage response = await client.PostAsync(uri, content);
response.EnsureSuccessStatusCode();
return await response.Content.ReadAsStringAsync();

View File

@ -185,6 +185,94 @@ namespace Microsoft.Net.Http.Server
}
}
[Fact]
public async Task ResponseBody_WriteAsyncWithActiveCancellationToken_Success()
{
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
Task<HttpResponseMessage> responseTask = SendRequestAsync(address);
var context = await server.GetContextAsync();
var cts = new CancellationTokenSource();
// First write sends headers
await context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token);
await context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token);
context.Dispose();
HttpResponseMessage response = await responseTask;
Assert.Equal(200, (int)response.StatusCode);
Assert.Equal(new byte[20], await response.Content.ReadAsByteArrayAsync());
}
}
[Fact]
public async Task ResponseBody_WriteAsyncWithTimerCancellationToken_Success()
{
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
Task<HttpResponseMessage> responseTask = SendRequestAsync(address);
var context = await server.GetContextAsync();
var cts = new CancellationTokenSource();
cts.CancelAfter(TimeSpan.FromSeconds(1));
// First write sends headers
await context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token);
await context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token);
context.Dispose();
HttpResponseMessage response = await responseTask;
Assert.Equal(200, (int)response.StatusCode);
Assert.Equal(new byte[20], await response.Content.ReadAsByteArrayAsync());
}
}
[Fact]
public async Task ResponseBody_FirstWriteAsyncWithCancelledCancellationToken_CancelsButDoesNotAbort()
{
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
Task<HttpResponseMessage> responseTask = SendRequestAsync(address);
var context = await server.GetContextAsync();
var cts = new CancellationTokenSource();
cts.Cancel();
// First write sends headers
var writeTask = context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token);
Assert.True(writeTask.IsCanceled);
context.Dispose();
HttpResponseMessage response = await responseTask;
Assert.Equal(200, (int)response.StatusCode);
Assert.Equal(new byte[0], await response.Content.ReadAsByteArrayAsync());
}
}
[Fact]
public async Task ResponseBody_SecondWriteAsyncWithCancelledCancellationToken_CancelsButDoesNotAbort()
{
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
Task<HttpResponseMessage> responseTask = SendRequestAsync(address);
var context = await server.GetContextAsync();
var cts = new CancellationTokenSource();
// First write sends headers
await context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token);
cts.Cancel();
var writeTask = context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token);
Assert.True(writeTask.IsCanceled);
context.Dispose();
HttpResponseMessage response = await responseTask;
Assert.Equal(200, (int)response.StatusCode);
Assert.Equal(new byte[10], await response.Content.ReadAsByteArrayAsync());
}
}
private async Task<HttpResponseMessage> SendRequestAsync(string uri)
{
using (HttpClient client = new HttpClient())