diff --git a/src/Microsoft.Net.Http.Server/NativeInterop/UnsafeNativeMethods.cs b/src/Microsoft.Net.Http.Server/NativeInterop/UnsafeNativeMethods.cs index cac47b387d..3d005ba6b7 100644 --- a/src/Microsoft.Net.Http.Server/NativeInterop/UnsafeNativeMethods.cs +++ b/src/Microsoft.Net.Http.Server/NativeInterop/UnsafeNativeMethods.cs @@ -193,7 +193,7 @@ namespace Microsoft.Net.Http.Server internal static extern uint HttpSendResponseEntityBody2(SafeHandle requestQueueHandle, ulong requestId, uint flags, ushort entityChunkCount, IntPtr pEntityChunks, out uint pBytesSent, SafeLocalFree pRequestBuffer, uint requestBufferLength, SafeHandle pOverlapped, IntPtr pLogData); [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] - internal static extern uint HttpWaitForDisconnect(SafeHandle requestQueueHandle, ulong connectionId, SafeNativeOverlapped pOverlapped); + internal static extern uint HttpWaitForDisconnectEx(SafeHandle requestQueueHandle, ulong connectionId, uint reserved, SafeNativeOverlapped overlapped); [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] internal static extern uint HttpCreateServerSession(HTTPAPI_VERSION version, ulong* serverSessionId, uint reserved); diff --git a/src/Microsoft.Net.Http.Server/WebListener.cs b/src/Microsoft.Net.Http.Server/WebListener.cs index d68eceb01d..e897f05e42 100644 --- a/src/Microsoft.Net.Http.Server/WebListener.cs +++ b/src/Microsoft.Net.Http.Server/WebListener.cs @@ -768,7 +768,7 @@ namespace Microsoft.Net.Http.Server uint statusCode; try { - statusCode = UnsafeNclNativeMethods.HttpApi.HttpWaitForDisconnect(_requestQueueHandle, connectionId, nativeOverlapped); + statusCode = UnsafeNclNativeMethods.HttpApi.HttpWaitForDisconnectEx(_requestQueueHandle, connectionId, 0, nativeOverlapped); } catch (Win32Exception exception) { diff --git a/test/Microsoft.Net.Http.Server.FunctionalTests/ServerTests.cs b/test/Microsoft.Net.Http.Server.FunctionalTests/ServerTests.cs index 23c641dea3..f4abdb333d 100644 --- a/test/Microsoft.Net.Http.Server.FunctionalTests/ServerTests.cs +++ b/test/Microsoft.Net.Http.Server.FunctionalTests/ServerTests.cs @@ -81,24 +81,25 @@ namespace Microsoft.Net.Http.Server string address; using (var server = Utilities.CreateHttpServer(out address)) { - // Note: System.Net.Sockets does not RST the connection by default, it just FINs. - // Http.Sys's disconnect notice requires a RST. - Task responseTask = SendHungRequestAsync("GET", address); - - var context = await server.GetContextAsync(); - CancellationToken ct = context.DisconnectToken; - Assert.True(ct.CanBeCanceled, "CanBeCanceled"); - Assert.False(ct.IsCancellationRequested, "IsCancellationRequested"); - ct.Register(() => canceled.Set()); - - using (Socket socket = await responseTask) + using (var client = new HttpClient()) { - socket.Close(0); // Force a RST - } - Assert.True(canceled.WaitOne(interval), "canceled"); - Assert.True(ct.IsCancellationRequested, "IsCancellationRequested"); + Task responseTask = client.GetAsync(address); - context.Dispose(); + var context = await server.GetContextAsync(); + CancellationToken ct = context.DisconnectToken; + Assert.True(ct.CanBeCanceled, "CanBeCanceled"); + Assert.False(ct.IsCancellationRequested, "IsCancellationRequested"); + ct.Register(() => canceled.Set()); + + client.CancelPendingRequests(); + + Assert.True(canceled.WaitOne(interval), "canceled"); + Assert.True(ct.IsCancellationRequested, "IsCancellationRequested"); + + await Assert.ThrowsAsync(() => responseTask); + + context.Dispose(); + } } } @@ -111,9 +112,7 @@ namespace Microsoft.Net.Http.Server string address; using (var server = Utilities.CreateHttpServer(out address)) { - // Note: System.Net.Sockets does not RST the connection by default, it just FINs. - // Http.Sys's disconnect notice requires a RST. - Task responseTask = SendHungRequestAsync("GET", address); + var responseTask = SendRequestAsync(address); var context = await server.GetContextAsync(); CancellationToken ct = context.DisconnectToken; @@ -124,10 +123,44 @@ namespace Microsoft.Net.Http.Server Assert.True(canceled.WaitOne(interval), "Aborted"); Assert.True(ct.IsCancellationRequested, "IsCancellationRequested"); - using (Socket socket = await responseTask) + // HttpClient re-tries the request because it doesn't know if the request was received. + context = await server.GetContextAsync(); + context.Abort(); + + await Assert.ThrowsAsync(() => responseTask); + } + } + + [Fact] + public async Task Server_ConnectionCloseHeader_CancellationTokenFires() + { + TimeSpan interval = TimeSpan.FromSeconds(1); + ManualResetEvent canceled = new ManualResetEvent(false); + + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + Task responseTask = SendRequestAsync(address); + + var context = await server.GetContextAsync(); + CancellationToken ct = context.DisconnectToken; + Assert.True(ct.CanBeCanceled, "CanBeCanceled"); + Assert.False(ct.IsCancellationRequested, "IsCancellationRequested"); + ct.Register(() => canceled.Set()); + + context.Response.Headers["Connection"] = "close"; + + context.Response.ContentLength = 11; + using (var writer = new StreamWriter(context.Response.Body)) { - Assert.Throws(() => socket.Receive(new byte[10])); + writer.Write("Hello World"); } + + Assert.True(canceled.WaitOne(interval), "Disconnected"); + Assert.True(ct.IsCancellationRequested, "IsCancellationRequested"); + + string response = await responseTask; + Assert.Equal("Hello World", response); } } @@ -166,48 +199,5 @@ namespace Microsoft.Net.Http.Server return await response.Content.ReadAsStringAsync(); } } - - private async Task SendHungRequestAsync(string method, string address) - { - // Connect with a socket - Uri uri = new Uri(address); - TcpClient client = new TcpClient(); - try - { - await client.ConnectAsync(uri.Host, uri.Port); - NetworkStream stream = client.GetStream(); - - // Send an HTTP GET request - byte[] requestBytes = BuildGetRequest(method, uri); - await stream.WriteAsync(requestBytes, 0, requestBytes.Length); - - // Return the opaque network stream - return client.Client; - } - catch (Exception) - { - client.Close(); - throw; - } - } - - private byte[] BuildGetRequest(string method, Uri uri) - { - StringBuilder builder = new StringBuilder(); - builder.Append(method); - builder.Append(" "); - builder.Append(uri.PathAndQuery); - builder.Append(" HTTP/1.1"); - builder.AppendLine(); - - builder.Append("Host: "); - builder.Append(uri.Host); - builder.Append(':'); - builder.Append(uri.Port); - builder.AppendLine(); - - builder.AppendLine(); - return Encoding.ASCII.GetBytes(builder.ToString()); - } } }