#67 - Fire disconnect notifications even for gracefull disconnects.

This commit is contained in:
Chris Ross 2014-09-24 11:33:47 -07:00
parent 4d2b2a14d5
commit 66144c864e
3 changed files with 56 additions and 66 deletions

View File

@ -193,7 +193,7 @@ namespace Microsoft.Net.Http.Server
internal static extern uint HttpSendResponseEntityBody2(SafeHandle requestQueueHandle, ulong requestId, uint flags, ushort entityChunkCount, IntPtr pEntityChunks, out uint pBytesSent, SafeLocalFree pRequestBuffer, uint requestBufferLength, SafeHandle pOverlapped, IntPtr pLogData);
[DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)]
internal static extern uint HttpWaitForDisconnect(SafeHandle requestQueueHandle, ulong connectionId, SafeNativeOverlapped pOverlapped);
internal static extern uint HttpWaitForDisconnectEx(SafeHandle requestQueueHandle, ulong connectionId, uint reserved, SafeNativeOverlapped overlapped);
[DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)]
internal static extern uint HttpCreateServerSession(HTTPAPI_VERSION version, ulong* serverSessionId, uint reserved);

View File

@ -768,7 +768,7 @@ namespace Microsoft.Net.Http.Server
uint statusCode;
try
{
statusCode = UnsafeNclNativeMethods.HttpApi.HttpWaitForDisconnect(_requestQueueHandle, connectionId, nativeOverlapped);
statusCode = UnsafeNclNativeMethods.HttpApi.HttpWaitForDisconnectEx(_requestQueueHandle, connectionId, 0, nativeOverlapped);
}
catch (Win32Exception exception)
{

View File

@ -81,24 +81,25 @@ namespace Microsoft.Net.Http.Server
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
// Note: System.Net.Sockets does not RST the connection by default, it just FINs.
// Http.Sys's disconnect notice requires a RST.
Task<Socket> responseTask = SendHungRequestAsync("GET", address);
var context = await server.GetContextAsync();
CancellationToken ct = context.DisconnectToken;
Assert.True(ct.CanBeCanceled, "CanBeCanceled");
Assert.False(ct.IsCancellationRequested, "IsCancellationRequested");
ct.Register(() => canceled.Set());
using (Socket socket = await responseTask)
using (var client = new HttpClient())
{
socket.Close(0); // Force a RST
}
Assert.True(canceled.WaitOne(interval), "canceled");
Assert.True(ct.IsCancellationRequested, "IsCancellationRequested");
Task<HttpResponseMessage> responseTask = client.GetAsync(address);
context.Dispose();
var context = await server.GetContextAsync();
CancellationToken ct = context.DisconnectToken;
Assert.True(ct.CanBeCanceled, "CanBeCanceled");
Assert.False(ct.IsCancellationRequested, "IsCancellationRequested");
ct.Register(() => canceled.Set());
client.CancelPendingRequests();
Assert.True(canceled.WaitOne(interval), "canceled");
Assert.True(ct.IsCancellationRequested, "IsCancellationRequested");
await Assert.ThrowsAsync<TaskCanceledException>(() => responseTask);
context.Dispose();
}
}
}
@ -111,9 +112,7 @@ namespace Microsoft.Net.Http.Server
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
// Note: System.Net.Sockets does not RST the connection by default, it just FINs.
// Http.Sys's disconnect notice requires a RST.
Task<Socket> responseTask = SendHungRequestAsync("GET", address);
var responseTask = SendRequestAsync(address);
var context = await server.GetContextAsync();
CancellationToken ct = context.DisconnectToken;
@ -124,10 +123,44 @@ namespace Microsoft.Net.Http.Server
Assert.True(canceled.WaitOne(interval), "Aborted");
Assert.True(ct.IsCancellationRequested, "IsCancellationRequested");
using (Socket socket = await responseTask)
// HttpClient re-tries the request because it doesn't know if the request was received.
context = await server.GetContextAsync();
context.Abort();
await Assert.ThrowsAsync<HttpRequestException>(() => responseTask);
}
}
[Fact]
public async Task Server_ConnectionCloseHeader_CancellationTokenFires()
{
TimeSpan interval = TimeSpan.FromSeconds(1);
ManualResetEvent canceled = new ManualResetEvent(false);
string address;
using (var server = Utilities.CreateHttpServer(out address))
{
Task<string> responseTask = SendRequestAsync(address);
var context = await server.GetContextAsync();
CancellationToken ct = context.DisconnectToken;
Assert.True(ct.CanBeCanceled, "CanBeCanceled");
Assert.False(ct.IsCancellationRequested, "IsCancellationRequested");
ct.Register(() => canceled.Set());
context.Response.Headers["Connection"] = "close";
context.Response.ContentLength = 11;
using (var writer = new StreamWriter(context.Response.Body))
{
Assert.Throws<SocketException>(() => socket.Receive(new byte[10]));
writer.Write("Hello World");
}
Assert.True(canceled.WaitOne(interval), "Disconnected");
Assert.True(ct.IsCancellationRequested, "IsCancellationRequested");
string response = await responseTask;
Assert.Equal("Hello World", response);
}
}
@ -166,48 +199,5 @@ namespace Microsoft.Net.Http.Server
return await response.Content.ReadAsStringAsync();
}
}
private async Task<Socket> SendHungRequestAsync(string method, string address)
{
// Connect with a socket
Uri uri = new Uri(address);
TcpClient client = new TcpClient();
try
{
await client.ConnectAsync(uri.Host, uri.Port);
NetworkStream stream = client.GetStream();
// Send an HTTP GET request
byte[] requestBytes = BuildGetRequest(method, uri);
await stream.WriteAsync(requestBytes, 0, requestBytes.Length);
// Return the opaque network stream
return client.Client;
}
catch (Exception)
{
client.Close();
throw;
}
}
private byte[] BuildGetRequest(string method, Uri uri)
{
StringBuilder builder = new StringBuilder();
builder.Append(method);
builder.Append(" ");
builder.Append(uri.PathAndQuery);
builder.Append(" HTTP/1.1");
builder.AppendLine();
builder.Append("Host: ");
builder.Append(uri.Host);
builder.Append(':');
builder.Append(uri.Port);
builder.AppendLine();
builder.AppendLine();
return Encoding.ASCII.GetBytes(builder.ToString());
}
}
}