diff --git a/src/Microsoft.Net.Http.Server/Helpers.cs b/src/Microsoft.Net.Http.Server/Helpers.cs index 60dad126ec..a8e0a2e5a9 100644 --- a/src/Microsoft.Net.Http.Server/Helpers.cs +++ b/src/Microsoft.Net.Http.Server/Helpers.cs @@ -33,6 +33,13 @@ namespace Microsoft.Net.Http.Server return Task.FromResult(null); } + internal static Task CancelledTask() + { + TaskCompletionSource tcs = new TaskCompletionSource(); + tcs.TrySetCanceled(); + return tcs.Task; + } + internal static ConfiguredTaskAwaitable SupressContext(this Task task) { return task.ConfigureAwait(continueOnCapturedContext: false); diff --git a/src/Microsoft.Net.Http.Server/RequestProcessing/RequestContext.cs b/src/Microsoft.Net.Http.Server/RequestProcessing/RequestContext.cs index 5a47b0498c..0c3ea2baf4 100644 --- a/src/Microsoft.Net.Http.Server/RequestProcessing/RequestContext.cs +++ b/src/Microsoft.Net.Http.Server/RequestProcessing/RequestContext.cs @@ -37,6 +37,8 @@ namespace Microsoft.Net.Http.Server { public sealed class RequestContext : IDisposable { + internal static Action AbortDelegate = Abort; + private WebListener _server; private Request _request; private Response _response; @@ -403,6 +405,12 @@ namespace Microsoft.Net.Http.Server _request.Dispose(); } + private static void Abort(object state) + { + var context = (RequestContext)state; + context.Abort(); + } + // This is only called while processing incoming requests. We don't have to worry about cancelling // any response writes. [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = diff --git a/src/Microsoft.Net.Http.Server/RequestProcessing/RequestStream.cs b/src/Microsoft.Net.Http.Server/RequestProcessing/RequestStream.cs index fb93d95eb3..aed48a5383 100644 --- a/src/Microsoft.Net.Http.Server/RequestProcessing/RequestStream.cs +++ b/src/Microsoft.Net.Http.Server/RequestProcessing/RequestStream.cs @@ -319,8 +319,10 @@ namespace Microsoft.Net.Http.Server return Task.FromResult(0); } - // TODO: Needs full cancellation integration - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Helpers.CancelledTask(); + } // TODO: Verbose log parameters RequestStreamAsyncResult asyncResult = null; @@ -349,7 +351,13 @@ namespace Microsoft.Net.Http.Server size = MaxReadSize; } - asyncResult = new RequestStreamAsyncResult(this, null, null, buffer, offset, dataRead); + CancellationTokenRegistration cancellationRegistration; + if (cancellationToken.CanBeCanceled) + { + cancellationRegistration = cancellationToken.Register(RequestContext.AbortDelegate, _requestContext); + } + + asyncResult = new RequestStreamAsyncResult(this, null, null, buffer, offset, dataRead, cancellationRegistration); uint bytesReturned; try @@ -449,6 +457,7 @@ namespace Microsoft.Net.Http.Server private TaskCompletionSource _tcs; private RequestStream _requestStream; private AsyncCallback _callback; + private CancellationTokenRegistration _cancellationRegistration; internal RequestStreamAsyncResult(RequestStream requestStream, object userState, AsyncCallback callback) { @@ -464,6 +473,11 @@ namespace Microsoft.Net.Http.Server } internal RequestStreamAsyncResult(RequestStream requestStream, object userState, AsyncCallback callback, byte[] buffer, int offset, uint dataAlreadyRead) + : this(requestStream, userState, callback, buffer, offset, dataAlreadyRead, new CancellationTokenRegistration()) + { + } + + internal RequestStreamAsyncResult(RequestStream requestStream, object userState, AsyncCallback callback, byte[] buffer, int offset, uint dataAlreadyRead, CancellationTokenRegistration cancellationRegistration) : this(requestStream, userState, callback) { _dataAlreadyRead = dataAlreadyRead; @@ -471,6 +485,7 @@ namespace Microsoft.Net.Http.Server overlapped.AsyncResult = this; _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, buffer)); _pinnedBuffer = (Marshal.UnsafeAddrOfPinnedArrayElement(buffer, offset)); + _cancellationRegistration = cancellationRegistration; } internal RequestStream RequestStream @@ -583,6 +598,7 @@ namespace Microsoft.Net.Http.Server { _overlapped.Dispose(); } + _cancellationRegistration.Dispose(); } } diff --git a/src/Microsoft.Net.Http.Server/RequestProcessing/ResponseStream.cs b/src/Microsoft.Net.Http.Server/RequestProcessing/ResponseStream.cs index 61825d7d73..ce086c4ca0 100644 --- a/src/Microsoft.Net.Http.Server/RequestProcessing/ResponseStream.cs +++ b/src/Microsoft.Net.Http.Server/RequestProcessing/ResponseStream.cs @@ -135,12 +135,20 @@ namespace Microsoft.Net.Http.Server UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); // TODO: Verbose log - // TODO: Real cancellation - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Helpers.CancelledTask(); + } + + CancellationTokenRegistration cancellationRegistration; + if (cancellationToken.CanBeCanceled) + { + cancellationRegistration = cancellationToken.Register(RequestContext.AbortDelegate, _requestContext); + } // TODO: Don't add MoreData flag if content-length == 0? flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; - ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, null, 0, 0, _requestContext.Response.BoundaryType == BoundaryType.Chunked, false); + ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, null, 0, 0, _requestContext.Response.BoundaryType == BoundaryType.Chunked, false, cancellationRegistration); try { @@ -492,7 +500,7 @@ namespace Microsoft.Net.Http.Server } } - public override unsafe Task WriteAsync(byte[] buffer, int offset, int size, CancellationToken cancel) + public override unsafe Task WriteAsync(byte[] buffer, int offset, int size, CancellationToken cancellationToken) { if (buffer == null) { @@ -521,14 +529,22 @@ namespace Microsoft.Net.Http.Server } // TODO: Verbose log - // TODO: Real cancelation - cancel.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Helpers.CancelledTask(); + } + + CancellationTokenRegistration cancellationRegistration; + if (cancellationToken.CanBeCanceled) + { + cancellationRegistration = cancellationToken.Register(RequestContext.AbortDelegate, _requestContext); + } uint statusCode; uint bytesSent = 0; flags |= _leftToWrite == size ? UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE : UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; bool sentHeaders = _requestContext.Response.SentHeaders; - ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, buffer, offset, size, _requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders); + ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, buffer, offset, size, _requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders, cancellationRegistration); // Update m_LeftToWrite now so we can queue up additional BeginWrite's without waiting for EndWrite. UpdateWritenCount((uint)((_requestContext.Response.BoundaryType == BoundaryType.Chunked) ? 0 : size)); @@ -597,7 +613,7 @@ namespace Microsoft.Net.Http.Server return asyncResult.Task; } - internal unsafe Task SendFileAsync(string fileName, long offset, long? size, CancellationToken cancel) + internal unsafe Task SendFileAsync(string fileName, long offset, long? size, CancellationToken cancellationToken) { // It's too expensive to validate the file attributes before opening the file. Open the file and then check the lengths. // This all happens inside of ResponseStreamAsyncResult. @@ -610,9 +626,6 @@ namespace Microsoft.Net.Http.Server throw new ObjectDisposedException(GetType().FullName); } - // TODO: Real cancellation - cancel.ThrowIfCancellationRequested(); - UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); if (size == 0 && _leftToWrite != 0) { @@ -624,12 +637,23 @@ namespace Microsoft.Net.Http.Server } // TODO: Verbose log + if (cancellationToken.IsCancellationRequested) + { + return Helpers.CancelledTask(); + } + + CancellationTokenRegistration cancellationRegistration; + if (cancellationToken.CanBeCanceled) + { + cancellationRegistration = cancellationToken.Register(RequestContext.AbortDelegate, _requestContext); + } + uint statusCode; uint bytesSent = 0; flags |= _leftToWrite == size ? UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE : UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; bool sentHeaders = _requestContext.Response.SentHeaders; ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, fileName, offset, size, - _requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders); + _requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders, cancellationRegistration); long bytesWritten; if (_requestContext.Response.BoundaryType == BoundaryType.Chunked) diff --git a/src/Microsoft.Net.Http.Server/RequestProcessing/ResponseStreamAsyncResult.cs b/src/Microsoft.Net.Http.Server/RequestProcessing/ResponseStreamAsyncResult.cs index 883459d132..6c829cddba 100644 --- a/src/Microsoft.Net.Http.Server/RequestProcessing/ResponseStreamAsyncResult.cs +++ b/src/Microsoft.Net.Http.Server/RequestProcessing/ResponseStreamAsyncResult.cs @@ -43,6 +43,7 @@ namespace Microsoft.Net.Http.Server private TaskCompletionSource _tcs; private AsyncCallback _callback; private uint _bytesSent; + private CancellationTokenRegistration _cancellationRegistration; internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback) { @@ -50,12 +51,20 @@ namespace Microsoft.Net.Http.Server _tcs = new TaskCompletionSource(userState); _callback = callback; } - internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback, byte[] buffer, int offset, int size, bool chunked, bool sentHeaders) + : this(responseStream, userState, callback, buffer, offset, size, chunked, sentHeaders, + new CancellationTokenRegistration()) + { + } + + internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback, + byte[] buffer, int offset, int size, bool chunked, bool sentHeaders, + CancellationTokenRegistration cancellationRegistration) : this(responseStream, userState, callback) { _sentHeaders = sentHeaders; + _cancellationRegistration = cancellationRegistration; Overlapped overlapped = new Overlapped(); overlapped.AsyncResult = this; @@ -121,10 +130,12 @@ namespace Microsoft.Net.Http.Server } internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback, - string fileName, long offset, long? size, bool chunked, bool sentHeaders) + string fileName, long offset, long? size, bool chunked, bool sentHeaders, + CancellationTokenRegistration cancellationRegistration) : this(responseStream, userState, callback) { _sentHeaders = sentHeaders; + _cancellationRegistration = cancellationRegistration; Overlapped overlapped = new Overlapped(); overlapped.AsyncResult = this; @@ -449,6 +460,7 @@ namespace Microsoft.Net.Http.Server { _fileStream.Dispose(); } + _cancellationRegistration.Dispose(); } } } diff --git a/test/Microsoft.Net.Http.Server.FunctionalTests/RequestBodyTests.cs b/test/Microsoft.Net.Http.Server.FunctionalTests/RequestBodyTests.cs index eddb4388d4..47dfbcb117 100644 --- a/test/Microsoft.Net.Http.Server.FunctionalTests/RequestBodyTests.cs +++ b/test/Microsoft.Net.Http.Server.FunctionalTests/RequestBodyTests.cs @@ -168,6 +168,156 @@ namespace Microsoft.Net.Http.Server } } + [Fact] + public async Task RequestBody_ReadAsyncAlreadyCancelled_ReturnsCanceledTask() + { + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + Task responseTask = SendRequestAsync(address, "Hello World"); + + var context = await server.GetContextAsync(); + + byte[] input = new byte[10]; + var cts = new CancellationTokenSource(); + cts.Cancel(); + + Task task = context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token); + Assert.True(task.IsCanceled); + + context.Dispose(); + + string response = await responseTask; + Assert.Equal(string.Empty, response); + } + } + + [Fact] + public async Task RequestBody_ReadAsyncPartialBodyWithCancellationToken_Success() + { + StaggardContent content = new StaggardContent(); + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + Task responseTask = SendRequestAsync(address, content); + + var context = await server.GetContextAsync(); + byte[] input = new byte[10]; + var cts = new CancellationTokenSource(); + int read = await context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token); + Assert.Equal(5, read); + content.Block.Release(); + read = await context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token); + Assert.Equal(5, read); + context.Dispose(); + + string response = await responseTask; + Assert.Equal(string.Empty, response); + } + } + + [Fact] + public async Task RequestBody_ReadAsyncPartialBodyWithTimeout_Success() + { + StaggardContent content = new StaggardContent(); + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + Task responseTask = SendRequestAsync(address, content); + + var context = await server.GetContextAsync(); + byte[] input = new byte[10]; + var cts = new CancellationTokenSource(); + cts.CancelAfter(TimeSpan.FromSeconds(5)); + int read = await context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token); + Assert.Equal(5, read); + content.Block.Release(); + read = await context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token); + Assert.Equal(5, read); + context.Dispose(); + + string response = await responseTask; + Assert.Equal(string.Empty, response); + } + } + + [Fact] + public async Task RequestBody_ReadAsyncPartialBodyAndCancel_Canceled() + { + StaggardContent content = new StaggardContent(); + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + Task responseTask = SendRequestAsync(address, content); + + var context = await server.GetContextAsync(); + byte[] input = new byte[10]; + var cts = new CancellationTokenSource(); + int read = await context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token); + Assert.Equal(5, read); + var readTask = context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token); + Assert.False(readTask.IsCanceled); + cts.Cancel(); + await Assert.ThrowsAsync(async () => await readTask); + content.Block.Release(); + context.Dispose(); + + await Assert.ThrowsAsync(async () => await responseTask); + } + } + + [Fact] + public async Task RequestBody_ReadAsyncPartialBodyAndExpiredTimeout_Canceled() + { + StaggardContent content = new StaggardContent(); + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + Task responseTask = SendRequestAsync(address, content); + + var context = await server.GetContextAsync(); + byte[] input = new byte[10]; + var cts = new CancellationTokenSource(); + int read = await context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token); + Assert.Equal(5, read); + cts.CancelAfter(TimeSpan.FromMilliseconds(100)); + var readTask = context.Request.Body.ReadAsync(input, 0, input.Length, cts.Token); + Assert.False(readTask.IsCanceled); + await Assert.ThrowsAsync(async () => await readTask); + content.Block.Release(); + context.Dispose(); + + await Assert.ThrowsAsync(async () => await responseTask); + } + } + + // Make sure that using our own disconnect token as a read cancellation token doesn't + // cause recursion problems when it fires and calls Abort. + [Fact] + public async Task RequestBody_ReadAsyncPartialBodyAndDisconnectedClient_Canceled() + { + StaggardContent content = new StaggardContent(); + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + var client = new HttpClient(); + var responseTask = client.PostAsync(address, content); + + var context = await server.GetContextAsync(); + byte[] input = new byte[10]; + int read = await context.Request.Body.ReadAsync(input, 0, input.Length, context.DisconnectToken); + Assert.False(context.DisconnectToken.IsCancellationRequested); + // The client should timeout and disconnect, making this read fail. + var assertTask = Assert.ThrowsAsync(async () => await context.Request.Body.ReadAsync(input, 0, input.Length, context.DisconnectToken)); + client.CancelPendingRequests(); + await assertTask; + content.Block.Release(); + context.Dispose(); + + await Assert.ThrowsAsync(async () => await responseTask); + } + } + private Task SendRequestAsync(string uri, string upload) { return SendRequestAsync(uri, new StringContent(upload)); @@ -177,6 +327,7 @@ namespace Microsoft.Net.Http.Server { using (HttpClient client = new HttpClient()) { + client.Timeout = TimeSpan.FromSeconds(10); HttpResponseMessage response = await client.PostAsync(uri, content); response.EnsureSuccessStatusCode(); return await response.Content.ReadAsStringAsync(); diff --git a/test/Microsoft.Net.Http.Server.FunctionalTests/ResponseBodyTests.cs b/test/Microsoft.Net.Http.Server.FunctionalTests/ResponseBodyTests.cs index bd9ab9bf0f..0a9a476ff7 100644 --- a/test/Microsoft.Net.Http.Server.FunctionalTests/ResponseBodyTests.cs +++ b/test/Microsoft.Net.Http.Server.FunctionalTests/ResponseBodyTests.cs @@ -185,6 +185,94 @@ namespace Microsoft.Net.Http.Server } } + [Fact] + public async Task ResponseBody_WriteAsyncWithActiveCancellationToken_Success() + { + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + Task responseTask = SendRequestAsync(address); + + var context = await server.GetContextAsync(); + var cts = new CancellationTokenSource(); + // First write sends headers + await context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token); + await context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token); + context.Dispose(); + + HttpResponseMessage response = await responseTask; + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new byte[20], await response.Content.ReadAsByteArrayAsync()); + } + } + + [Fact] + public async Task ResponseBody_WriteAsyncWithTimerCancellationToken_Success() + { + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + Task responseTask = SendRequestAsync(address); + + var context = await server.GetContextAsync(); + var cts = new CancellationTokenSource(); + cts.CancelAfter(TimeSpan.FromSeconds(1)); + // First write sends headers + await context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token); + await context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token); + context.Dispose(); + + HttpResponseMessage response = await responseTask; + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new byte[20], await response.Content.ReadAsByteArrayAsync()); + } + } + + [Fact] + public async Task ResponseBody_FirstWriteAsyncWithCancelledCancellationToken_CancelsButDoesNotAbort() + { + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + Task responseTask = SendRequestAsync(address); + + var context = await server.GetContextAsync(); + var cts = new CancellationTokenSource(); + cts.Cancel(); + // First write sends headers + var writeTask = context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token); + Assert.True(writeTask.IsCanceled); + context.Dispose(); + + HttpResponseMessage response = await responseTask; + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new byte[0], await response.Content.ReadAsByteArrayAsync()); + } + } + + [Fact] + public async Task ResponseBody_SecondWriteAsyncWithCancelledCancellationToken_CancelsButDoesNotAbort() + { + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + Task responseTask = SendRequestAsync(address); + + var context = await server.GetContextAsync(); + var cts = new CancellationTokenSource(); + // First write sends headers + await context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token); + cts.Cancel(); + var writeTask = context.Response.Body.WriteAsync(new byte[10], 0, 10, cts.Token); + Assert.True(writeTask.IsCanceled); + context.Dispose(); + + HttpResponseMessage response = await responseTask; + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new byte[10], await response.Content.ReadAsByteArrayAsync()); + } + } + private async Task SendRequestAsync(string uri) { using (HttpClient client = new HttpClient())