aspnetcore/test/Microsoft.Net.Http.Server.F.../ServerTests.cs

214 lines
7.5 KiB
C#

// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.Net.Http.Server
{
public class ServerTests
{
    /// <summary>
    /// Disposing a request context without writing anything produces a successful,
    /// empty response on the client.
    /// </summary>
    [Fact]
    public async Task Server_200OK_Success()
    {
        string address;
        using (var server = Utilities.CreateHttpServer(out address))
        {
            Task<string> responseTask = SendRequestAsync(address);

            var context = await server.GetContextAsync();
            context.Dispose();

            string response = await responseTask;
            Assert.Equal(string.Empty, response);
        }
    }

    /// <summary>
    /// Writing a fixed-length body to the response stream delivers it to the client.
    /// </summary>
    [Fact]
    public async Task Server_SendHelloWorld_Success()
    {
        string address;
        using (var server = Utilities.CreateHttpServer(out address))
        {
            Task<string> responseTask = SendRequestAsync(address);

            var context = await server.GetContextAsync();
            context.Response.ContentLength = 11;
            using (var writer = new StreamWriter(context.Response.Body))
            {
                writer.Write("Hello World");
            }

            string response = await responseTask;
            Assert.Equal("Hello World", response);
        }
    }

    /// <summary>
    /// The server can read an uploaded request body and write a response body back.
    /// </summary>
    [Fact]
    public async Task Server_EchoHelloWorld_Success()
    {
        string address;
        using (var server = Utilities.CreateHttpServer(out address))
        {
            Task<string> responseTask = SendRequestAsync(address, "Hello World");

            var context = await server.GetContextAsync();
            string input = new StreamReader(context.Request.Body).ReadToEnd();
            Assert.Equal("Hello World", input);
            context.Response.ContentLength = 11;
            using (var writer = new StreamWriter(context.Response.Body))
            {
                writer.Write("Hello World");
            }

            string response = await responseTask;
            Assert.Equal("Hello World", response);
        }
    }

    /// <summary>
    /// A client resetting its connection signals the request's DisconnectToken.
    /// </summary>
    [Fact]
    public async Task Server_ClientDisconnects_CallCanceled()
    {
        TimeSpan interval = TimeSpan.FromSeconds(1);
        string address;
        using (var canceled = new ManualResetEvent(false))
        using (var server = Utilities.CreateHttpServer(out address))
        {
            // Note: System.Net.Sockets does not RST the connection by default, it just FINs.
            // Http.Sys's disconnect notice requires a RST.
            Task<Socket> responseTask = SendHungRequestAsync("GET", address);
            var context = await server.GetContextAsync();
            try
            {
                CancellationToken ct = context.DisconnectToken;
                Assert.True(ct.CanBeCanceled, "CanBeCanceled");
                Assert.False(ct.IsCancellationRequested, "IsCancellationRequested");

                // Dispose the registration so the callback cannot fire after the event is disposed.
                using (ct.Register(() => canceled.Set()))
                {
                    using (Socket socket = await responseTask)
                    {
                        socket.Close(0); // Force a RST
                    }
                    Assert.True(canceled.WaitOne(interval), "canceled");
                    Assert.True(ct.IsCancellationRequested, "IsCancellationRequested");
                }
            }
            finally
            {
                // Release the context even when an assertion above fails.
                context.Dispose();
            }
        }
    }

    /// <summary>
    /// Aborting a request on the server signals its DisconnectToken and breaks the
    /// client's connection.
    /// </summary>
    [Fact]
    public async Task Server_Abort_CallCanceled()
    {
        TimeSpan interval = TimeSpan.FromSeconds(1);
        string address;
        using (var canceled = new ManualResetEvent(false))
        using (var server = Utilities.CreateHttpServer(out address))
        {
            // Use a raw socket so the test can observe the connection state after the
            // server-side Abort.
            Task<Socket> responseTask = SendHungRequestAsync("GET", address);
            var context = await server.GetContextAsync();
            CancellationToken ct = context.DisconnectToken;
            Assert.True(ct.CanBeCanceled, "CanBeCanceled");
            Assert.False(ct.IsCancellationRequested, "IsCancellationRequested");

            // Dispose the registration so the callback cannot fire after the event is disposed.
            using (ct.Register(() => canceled.Set()))
            {
                context.Abort();
                Assert.True(canceled.WaitOne(interval), "Aborted");
                Assert.True(ct.IsCancellationRequested, "IsCancellationRequested");
            }

            using (Socket socket = await responseTask)
            {
                Assert.Throws<SocketException>(() => socket.Receive(new byte[10]));
            }
        }
    }

    /// <summary>
    /// Setting the request queue limit does not prevent a request from being served.
    /// </summary>
    [Fact]
    public async Task Server_SetQueueLimit_Success()
    {
        string address;
        using (var server = Utilities.CreateHttpServer(out address))
        {
            server.SetRequestQueueLimit(1001);
            Task<string> responseTask = SendRequestAsync(address);

            var context = await server.GetContextAsync();
            context.Dispose();

            string response = await responseTask;
            Assert.Equal(string.Empty, response);
        }
    }

    /// <summary>
    /// Issues a GET to <paramref name="uri"/> and returns the response body as a string.
    /// </summary>
    /// <remarks>
    /// Raises the process-wide <see cref="ServicePointManager.DefaultConnectionLimit"/> to 100.
    /// </remarks>
    private async Task<string> SendRequestAsync(string uri)
    {
        ServicePointManager.DefaultConnectionLimit = 100;
        using (HttpClient client = new HttpClient())
        {
            return await client.GetStringAsync(uri);
        }
    }

    /// <summary>
    /// POSTs <paramref name="upload"/> to <paramref name="uri"/> and returns the response body.
    /// </summary>
    /// <exception cref="HttpRequestException">The response status code is not a success code.</exception>
    private async Task<string> SendRequestAsync(string uri, string upload)
    {
        using (HttpClient client = new HttpClient())
        {
            HttpResponseMessage response = await client.PostAsync(uri, new StringContent(upload));
            response.EnsureSuccessStatusCode();
            return await response.Content.ReadAsStringAsync();
        }
    }

    /// <summary>
    /// Opens a TCP connection to <paramref name="address"/> and writes a bodiless request
    /// without reading any response.
    /// </summary>
    /// <returns>The connected socket. The caller owns it and must dispose it.</returns>
    private async Task<Socket> SendHungRequestAsync(string method, string address)
    {
        // Connect with a socket
        Uri uri = new Uri(address);
        TcpClient client = new TcpClient();
        try
        {
            await client.ConnectAsync(uri.Host, uri.Port);
            NetworkStream stream = client.GetStream();

            // Send the request line and Host header using the given method; no body.
            byte[] requestBytes = BuildGetRequest(method, uri);
            await stream.WriteAsync(requestBytes, 0, requestBytes.Length);

            // Hand the underlying socket to the caller, who is responsible for closing it.
            return client.Client;
        }
        catch (Exception)
        {
            client.Close();
            throw;
        }
    }

    /// <summary>
    /// Builds the ASCII bytes of a bodiless HTTP/1.1 request for <paramref name="uri"/>
    /// using <paramref name="method"/> as the request method.
    /// </summary>
    private byte[] BuildGetRequest(string method, Uri uri)
    {
        // HTTP/1.1 requires CRLF line terminators. StringBuilder.AppendLine uses
        // Environment.NewLine, which is "\r\n" only on Windows.
        const string CrLf = "\r\n";

        StringBuilder builder = new StringBuilder();
        builder.Append(method);
        builder.Append(" ");
        builder.Append(uri.PathAndQuery);
        builder.Append(" HTTP/1.1");
        builder.Append(CrLf);
        builder.Append("Host: ");
        builder.Append(uri.Host);
        builder.Append(':');
        builder.Append(uri.Port);
        builder.Append(CrLf);
        // An empty line terminates the header section.
        builder.Append(CrLf);
        return Encoding.ASCII.GetBytes(builder.ToString());
    }
}
}