Abort request on client FIN (#1139).

This commit is contained in:
Cesar Blum Silveira 2016-11-04 11:10:36 -07:00
parent 51ecbd7949
commit cedbe76f52
20 changed files with 423 additions and 289 deletions

View File

@ -42,6 +42,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "shared", "shared", "{0EF2AC
test\shared\MockSocketOutput.cs = test\shared\MockSocketOutput.cs test\shared\MockSocketOutput.cs = test\shared\MockSocketOutput.cs
test\shared\MockSystemClock.cs = test\shared\MockSystemClock.cs test\shared\MockSystemClock.cs = test\shared\MockSystemClock.cs
test\shared\SocketInputExtensions.cs = test\shared\SocketInputExtensions.cs test\shared\SocketInputExtensions.cs = test\shared\SocketInputExtensions.cs
test\shared\TestApp.cs = test\shared\TestApp.cs
test\shared\TestApplicationErrorLogger.cs = test\shared\TestApplicationErrorLogger.cs test\shared\TestApplicationErrorLogger.cs = test\shared\TestApplicationErrorLogger.cs
test\shared\TestConnection.cs = test\shared\TestConnection.cs test\shared\TestConnection.cs = test\shared\TestConnection.cs
test\shared\TestFrame.cs = test\shared\TestFrame.cs test\shared\TestFrame.cs = test\shared\TestFrame.cs

View File

@ -10,6 +10,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Filter;
using Microsoft.AspNetCore.Server.Kestrel.Filter.Internal; using Microsoft.AspNetCore.Server.Kestrel.Filter.Internal;
using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.Internal.Networking; using Microsoft.AspNetCore.Server.Kestrel.Internal.Networking;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
@ -200,7 +201,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
_frame.Input = _filteredStreamAdapter.SocketInput; _frame.Input = _filteredStreamAdapter.SocketInput;
_frame.Output = _filteredStreamAdapter.SocketOutput; _frame.Output = _filteredStreamAdapter.SocketOutput;
_readInputTask = _filteredStreamAdapter.ReadInputAsync(); // Don't attempt to read input if connection has already closed.
// This can happen if a client opens a connection and immediately closes it.
_readInputTask = _socketClosedTcs.Task.Status == TaskStatus.WaitingForActivation ?
_filteredStreamAdapter.ReadInputAsync() :
TaskCache.CompletedTask;
} }
_frame.PrepareRequest = _filterContext.PrepareRequest; _frame.PrepareRequest = _filterContext.PrepareRequest;
@ -278,7 +283,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
Input.IncomingComplete(readCount, error); Input.IncomingComplete(readCount, error);
if (errorDone) if (!normalRead)
{ {
Abort(error); Abort(error);
} }

View File

@ -564,6 +564,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
} }
else else
{ {
CheckLastWrite();
Output.Write(data); Output.Write(data);
} }
} }
@ -599,6 +600,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
} }
else else
{ {
CheckLastWrite();
return Output.WriteAsync(data, cancellationToken: cancellationToken); return Output.WriteAsync(data, cancellationToken: cancellationToken);
} }
} }
@ -631,6 +633,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
} }
else else
{ {
CheckLastWrite();
await Output.WriteAsync(data, cancellationToken: cancellationToken); await Output.WriteAsync(data, cancellationToken: cancellationToken);
} }
} }
@ -658,6 +661,24 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
_responseBytesWritten += count; _responseBytesWritten += count;
} }
private void CheckLastWrite()
{
var responseHeaders = FrameResponseHeaders;
// Prevent firing request aborted token if this is the last write, to avoid
// aborting the request if the app is still running when the client receives
// the final bytes of the response and gracefully closes the connection.
//
// Called after VerifyAndUpdateWrite(), so _responseBytesWritten has already been updated.
if (responseHeaders != null &&
!responseHeaders.HasTransferEncoding &&
responseHeaders.HasContentLength &&
_responseBytesWritten == responseHeaders.HeaderContentLengthValue.Value)
{
_abortedCts = null;
}
}
protected void VerifyResponseContentLength() protected void VerifyResponseContentLength()
{ {
var responseHeaders = FrameResponseHeaders; var responseHeaders = FrameResponseHeaders;
@ -838,6 +859,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
private async Task WriteAutoChunkSuffixAwaited() private async Task WriteAutoChunkSuffixAwaited()
{ {
// For the same reason we call CheckLastWrite() in Content-Length responses.
_abortedCts = null;
await WriteChunkedResponseSuffix(); await WriteChunkedResponseSuffix();
if (_keepAlive) if (_keepAlive)

View File

@ -103,7 +103,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
} }
Assert.Equal(data.Length, bytesWritten); Assert.Equal(data.Length, bytesWritten);
socket.Shutdown(SocketShutdown.Send);
clientFinishedSendingRequestBody.Set(); clientFinishedSendingRequestBody.Set();
}; };

View File

@ -25,13 +25,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
[InlineData("DELETE /a%20b%20c/d%20e?f=ghi HTTP/1.1\r\n\r\n", 1027)] [InlineData("DELETE /a%20b%20c/d%20e?f=ghi HTTP/1.1\r\n\r\n", 1027)]
public async Task ServerAcceptsRequestLineWithinLimit(string request, int limit) public async Task ServerAcceptsRequestLineWithinLimit(string request, int limit)
{ {
var maxRequestLineSize = limit;
using (var server = CreateServer(limit)) using (var server = CreateServer(limit))
{ {
using (var connection = new TestConnection(server.Port)) using (var connection = new TestConnection(server.Port))
{ {
await connection.SendEnd(request); await connection.Send(request);
await connection.ReceiveEnd( await connection.ReceiveEnd(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
@ -57,8 +55,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
{ {
using (var connection = new TestConnection(server.Port)) using (var connection = new TestConnection(server.Port))
{ {
await connection.SendAllTryEnd($"{requestLine}\r\n"); await connection.SendAll(requestLine);
await connection.Receive( await connection.ReceiveForcedEnd(
"HTTP/1.1 414 URI Too Long", "HTTP/1.1 414 URI Too Long",
"Connection: close", "Connection: close",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",

View File

@ -28,7 +28,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
{ {
using (var connection = new TestConnection(server.Port)) using (var connection = new TestConnection(server.Port))
{ {
await connection.SendEnd($"GET / HTTP/1.1\r\n{headers}\r\n"); await connection.Send($"GET / HTTP/1.1\r\n{headers}\r\n");
await connection.ReceiveEnd( await connection.ReceiveEnd(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
@ -60,7 +60,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
{ {
using (var connection = new TestConnection(server.Port)) using (var connection = new TestConnection(server.Port))
{ {
await connection.SendEnd($"GET / HTTP/1.1\r\n{headers}\r\n"); await connection.Send($"GET / HTTP/1.1\r\n{headers}\r\n");
await connection.ReceiveEnd( await connection.ReceiveEnd(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
@ -86,7 +86,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
{ {
using (var connection = new TestConnection(server.Port)) using (var connection = new TestConnection(server.Port))
{ {
await connection.SendAllTryEnd($"GET / HTTP/1.1\r\n{headers}\r\n"); await connection.SendAll($"GET / HTTP/1.1\r\n{headers}\r\n");
await connection.ReceiveForcedEnd( await connection.ReceiveForcedEnd(
"HTTP/1.1 431 Request Header Fields Too Large", "HTTP/1.1 431 Request Header Fields Too Large",
"Connection: close", "Connection: close",
@ -110,7 +110,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
{ {
using (var connection = new TestConnection(server.Port)) using (var connection = new TestConnection(server.Port))
{ {
await connection.SendAllTryEnd($"GET / HTTP/1.1\r\n{headers}\r\n"); await connection.SendAll($"GET / HTTP/1.1\r\n{headers}\r\n");
await connection.ReceiveForcedEnd( await connection.ReceiveForcedEnd(
"HTTP/1.1 431 Request Header Fields Too Large", "HTTP/1.1 431 Request Header Fields Too Large",
"Connection: close", "Connection: close",

View File

@ -16,6 +16,7 @@ using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Internal.Networking; using Microsoft.AspNetCore.Server.Kestrel.Internal.Networking;
using Microsoft.AspNetCore.Testing;
using Microsoft.AspNetCore.Testing.xunit; using Microsoft.AspNetCore.Testing.xunit;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing; using Microsoft.Extensions.Logging.Testing;
@ -420,6 +421,38 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
} }
} }
[Fact]
public async Task RequestAbortedTokenFiredOnClientFIN()
{
var appStarted = new SemaphoreSlim(0);
var requestAborted = new SemaphoreSlim(0);
var builder = new WebHostBuilder()
.UseKestrel()
.UseUrls($"http://127.0.0.1:0")
.Configure(app => app.Run(async context =>
{
appStarted.Release();
var token = context.RequestAborted;
token.Register(() => requestAborted.Release(2));
await requestAborted.WaitAsync().TimeoutAfter(TimeSpan.FromSeconds(10));
}));
using (var host = builder.Build())
{
host.Start();
using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort()));
socket.Send(Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\n\r\n"));
await appStarted.WaitAsync();
socket.Shutdown(SocketShutdown.Send);
await requestAborted.WaitAsync().TimeoutAfter(TimeSpan.FromSeconds(10));
}
}
}
private async Task TestRemoteIPAddress(string registerAddress, string requestAddress, string expectAddress) private async Task TestRemoteIPAddress(string registerAddress, string requestAddress, string expectAddress)
{ {
var builder = new WebHostBuilder() var builder = new WebHostBuilder()

View File

@ -5,7 +5,6 @@ using System;
using System.Linq; using System.Linq;
using System.Net; using System.Net;
using System.Net.Http; using System.Net.Http;
using System.Net.Sockets;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -288,26 +287,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
disposedTcs.TrySetResult(c.Response.StatusCode); disposedTcs.TrySetResult(c.Response.StatusCode);
}); });
var hostBuilder = new WebHostBuilder() using (var server = new TestServer(handler, new TestServiceContext(), "http://127.0.0.1:0", mockHttpContextFactory.Object))
.UseKestrel()
.UseUrls("http://127.0.0.1:0")
.ConfigureServices(services => services.AddSingleton<IHttpContextFactory>(mockHttpContextFactory.Object))
.Configure(app =>
{ {
app.Run(handler);
});
using (var host = hostBuilder.Build())
{
host.Start();
if (!sendMalformedRequest) if (!sendMalformedRequest)
{ {
using (var client = new HttpClient()) using (var client = new HttpClient())
{ {
try try
{ {
var response = await client.GetAsync($"http://localhost:{host.GetPort()}/"); var response = await client.GetAsync($"http://127.0.0.1:{server.Port}/");
Assert.Equal(expectedClientStatusCode, response.StatusCode); Assert.Equal(expectedClientStatusCode, response.StatusCode);
} }
catch catch
@ -321,14 +309,20 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
} }
else else
{ {
using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) using (var connection = new TestConnection(server.Port))
{ {
socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort())); await connection.Send(
socket.Send(Encoding.ASCII.GetBytes( "POST / HTTP/1.1",
"POST / HTTP/1.1\r\n" + "Transfer-Encoding: chunked",
"Transfer-Encoding: chunked\r\n" + "",
"\r\n" + "gg");
"wrong")); await connection.ReceiveForcedEnd(
"HTTP/1.1 400 Bad Request",
"Connection: close",
$"Date: {server.Context.DateHeaderValue}",
"Content-Length: 0",
"",
"");
} }
} }
@ -453,11 +447,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
[Fact] [Fact]
public async Task ResponseBodyNotWrittenOnHeadResponseAndLoggedOnlyOnce() public async Task ResponseBodyNotWrittenOnHeadResponseAndLoggedOnlyOnce()
{ {
const string response = "hello, world";
var logTcs = new TaskCompletionSource<object>();
var mockKestrelTrace = new Mock<IKestrelTrace>(); var mockKestrelTrace = new Mock<IKestrelTrace>();
mockKestrelTrace
.Setup(trace => trace.ConnectionHeadResponseBodyWrite(It.IsAny<string>(), response.Length))
.Callback<string, long>((connectionId, count) => logTcs.SetResult(null));
using (var server = new TestServer(async httpContext => using (var server = new TestServer(async httpContext =>
{ {
await httpContext.Response.WriteAsync("hello, world"); await httpContext.Response.WriteAsync(response);
await httpContext.Response.Body.FlushAsync(); await httpContext.Response.Body.FlushAsync();
}, new TestServiceContext { Log = mockKestrelTrace.Object })) }, new TestServiceContext { Log = mockKestrelTrace.Object }))
{ {
@ -472,11 +472,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
"", "",
""); "");
// Wait for message to be logged before disposing the socket.
// Disposing the socket will abort the connection and Frame._requestAborted
// might be 1 by the time ProduceEnd() gets called and the message is logged.
await logTcs.Task.TimeoutAfter(TimeSpan.FromSeconds(10));
} }
} }
mockKestrelTrace.Verify(kestrelTrace => mockKestrelTrace.Verify(kestrelTrace =>
kestrelTrace.ConnectionHeadResponseBodyWrite(It.IsAny<string>(), "hello, world".Length), Times.Once); kestrelTrace.ConnectionHeadResponseBodyWrite(It.IsAny<string>(), response.Length), Times.Once);
} }
[Fact] [Fact]
@ -533,7 +538,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
await connection.ReceiveEnd( await connection.ReceiveForcedEnd(
$"HTTP/1.1 200 OK", $"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
"Content-Length: 11", "Content-Length: 11",
@ -566,7 +571,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
await connection.ReceiveEnd( await connection.ReceiveForcedEnd(
$"HTTP/1.1 500 Internal Server Error", $"HTTP/1.1 500 Internal Server Error",
"Connection: close", "Connection: close",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
@ -633,7 +638,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
await connection.ReceiveEnd( await connection.ReceiveForcedEnd(
$"HTTP/1.1 500 Internal Server Error", $"HTTP/1.1 500 Internal Server Error",
"Connection: close", "Connection: close",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
@ -858,7 +863,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
@ -880,7 +885,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
public async Task AppCanWriteOwnBadRequestResponse() public async Task AppCanWriteOwnBadRequestResponse()
{ {
var expectedResponse = string.Empty; var expectedResponse = string.Empty;
var responseWrittenTcs = new TaskCompletionSource<object>(); var responseWritten = new SemaphoreSlim(0);
using (var server = new TestServer(async httpContext => using (var server = new TestServer(async httpContext =>
{ {
@ -894,18 +899,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
httpContext.Response.StatusCode = 400; httpContext.Response.StatusCode = 400;
httpContext.Response.ContentLength = ex.Message.Length; httpContext.Response.ContentLength = ex.Message.Length;
await httpContext.Response.WriteAsync(ex.Message); await httpContext.Response.WriteAsync(ex.Message);
responseWrittenTcs.SetResult(null); responseWritten.Release();
} }
}, new TestServiceContext())) }, new TestServiceContext()))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"POST / HTTP/1.1", "POST / HTTP/1.1",
"Transfer-Encoding: chunked", "Transfer-Encoding: chunked",
"", "",
"bad"); "wrong");
await responseWrittenTcs.Task; await responseWritten.WaitAsync();
await connection.ReceiveEnd( await connection.ReceiveEnd(
"HTTP/1.1 400 Bad Request", "HTTP/1.1 400 Bad Request",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
@ -935,7 +940,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
await connection.ReceiveEnd( await connection.ReceiveForcedEnd(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
"Connection: close", "Connection: close",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
@ -951,7 +956,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
"Connection: keep-alive", "Connection: keep-alive",
"", "",
""); "");
await connection.ReceiveEnd( await connection.ReceiveForcedEnd(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
"Connection: close", "Connection: close",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
@ -982,7 +987,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
await connection.ReceiveEnd( await connection.ReceiveForcedEnd(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
"Connection: keep-alive", "Connection: keep-alive",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
@ -998,7 +1003,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
"Connection: keep-alive", "Connection: keep-alive",
"", "",
""); "");
await connection.ReceiveEnd( await connection.ReceiveForcedEnd(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
"Connection: keep-alive", "Connection: keep-alive",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
@ -1036,7 +1041,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
"hello, world"); "hello, world");
// Make sure connection was kept open // Make sure connection was kept open
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");

View File

@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendAllTryEnd(request); await connection.SendAll(request);
await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue);
} }
} }
@ -88,15 +88,13 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendAllTryEnd(request); await connection.SendAll(request);
await ReceiveBadRequestResponse(connection, "505 HTTP Version Not Supported", server.Context.DateHeaderValue); await ReceiveBadRequestResponse(connection, "505 HTTP Version Not Supported", server.Context.DateHeaderValue);
} }
} }
} }
[Theory] [Theory]
// Missing final CRLF
[InlineData("Header-1: value1\r\nHeader-2: value2\r\n")]
// Leading whitespace // Leading whitespace
[InlineData(" Header-1: value1\r\nHeader-2: value2\r\n\r\n")] [InlineData(" Header-1: value1\r\nHeader-2: value2\r\n\r\n")]
[InlineData("\tHeader-1: value1\r\nHeader-2: value2\r\n\r\n")] [InlineData("\tHeader-1: value1\r\nHeader-2: value2\r\n\r\n")]
@ -124,7 +122,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendAllTryEnd($"GET / HTTP/1.1\r\n{rawHeaders}"); await connection.SendAll($"GET / HTTP/1.1\r\n{rawHeaders}");
await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue);
} }
} }
@ -137,7 +135,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendAllTryEnd( await connection.SendAll(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"H\u00eb\u00e4d\u00ebr: value", "H\u00eb\u00e4d\u00ebr: value",
"", "",
@ -168,7 +166,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendAllTryEnd($"GET {path} HTTP/1.1\r\n"); await connection.SendAll($"GET {path} HTTP/1.1\r\n");
await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue);
} }
} }
@ -183,7 +181,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd($"{method} / HTTP/1.1\r\n\r\n"); await connection.Send($"{method} / HTTP/1.1\r\n\r\n");
await ReceiveBadRequestResponse(connection, "411 Length Required", server.Context.DateHeaderValue); await ReceiveBadRequestResponse(connection, "411 Length Required", server.Context.DateHeaderValue);
} }
} }
@ -198,7 +196,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd($"{method} / HTTP/1.0\r\n\r\n"); await connection.Send($"{method} / HTTP/1.0\r\n\r\n");
await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue);
} }
} }

View File

@ -68,13 +68,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"POST / HTTP/1.0", "POST / HTTP/1.0",
"Transfer-Encoding: chunked", "Transfer-Encoding: chunked",
"", "",
"5", "Hello", "5", "Hello",
"6", " World", "6", " World",
"0", "0",
"",
""); "");
await connection.ReceiveEnd( await connection.ReceiveEnd(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
@ -94,7 +95,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"POST / HTTP/1.0", "POST / HTTP/1.0",
"Transfer-Encoding: chunked", "Transfer-Encoding: chunked",
"Connection: keep-alive", "Connection: keep-alive",
@ -143,7 +144,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"POST / HTTP/1.1", "POST / HTTP/1.1",
"Content-Length: 5", "Content-Length: 5",
"", "",
@ -254,8 +255,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd(fullRequest); await connection.Send(fullRequest);
await connection.ReceiveEnd(expectedFullResponse); await connection.ReceiveEnd(expectedFullResponse);
} }
} }
@ -282,7 +282,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendAllTryEnd( await connection.SendAll(
"POST / HTTP/1.1", "POST / HTTP/1.1",
$"{transferEncodingHeaderLine}", $"{transferEncodingHeaderLine}",
$"{headerLine}", $"{headerLine}",
@ -322,7 +322,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendAllTryEnd( await connection.SendAll(
"POST / HTTP/1.1", "POST / HTTP/1.1",
$"{transferEncodingHeaderLine}", $"{transferEncodingHeaderLine}",
$"{headerLine}", $"{headerLine}",
@ -423,8 +423,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd(fullRequest); await connection.Send(fullRequest);
await connection.ReceiveEnd(expectedFullResponse); await connection.ReceiveEnd(expectedFullResponse);
} }
} }

View File

@ -43,7 +43,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
@ -75,7 +75,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.0", "GET / HTTP/1.0",
"Connection: keep-alive", "Connection: keep-alive",
"", "",
@ -102,7 +102,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"Connection: close", "Connection: close",
"", "",
@ -137,7 +137,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
@ -172,7 +172,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
@ -210,7 +210,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
@ -246,7 +246,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
@ -264,7 +264,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[Theory] [Theory]
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task ConnectionClosedIfExeptionThrownAfterWrite(TestServiceContext testContext) public async Task ConnectionClosedIfExceptionThrownAfterWrite(TestServiceContext testContext)
{ {
using (var server = new TestServer(async httpContext => using (var server = new TestServer(async httpContext =>
{ {
@ -295,7 +295,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[Theory] [Theory]
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task ConnectionClosedIfExeptionThrownAfterZeroLengthWrite(TestServiceContext testContext) public async Task ConnectionClosedIfExceptionThrownAfterZeroLengthWrite(TestServiceContext testContext)
{ {
using (var server = new TestServer(async httpContext => using (var server = new TestServer(async httpContext =>
{ {
@ -342,7 +342,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
@ -383,7 +383,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");

View File

@ -3,32 +3,17 @@
using System; using System;
using System.IO; using System.IO;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Server.Kestrel.Filter; using Microsoft.AspNetCore.Server.Kestrel.Filter;
using Microsoft.AspNetCore.Testing; using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Internal;
using Xunit; using Xunit;
namespace Microsoft.AspNetCore.Server.KestrelTests namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
public class ConnectionFilterTests public class ConnectionFilterTests
{ {
private async Task App(HttpContext httpContext)
{
var request = httpContext.Request;
var response = httpContext.Response;
while (true)
{
var buffer = new byte[8192];
var count = await request.Body.ReadAsync(buffer, 0, buffer.Length);
if (count == 0)
{
break;
}
await response.Body.WriteAsync(buffer, 0, count);
}
}
[Fact] [Fact]
public async Task CanReadAndWriteWithRewritingConnectionFilter() public async Task CanReadAndWriteWithRewritingConnectionFilter()
{ {
@ -37,12 +22,12 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
var sendString = "POST / HTTP/1.0\r\nContent-Length: 12\r\n\r\nHello World?"; var sendString = "POST / HTTP/1.0\r\nContent-Length: 12\r\n\r\nHello World?";
using (var server = new TestServer(App, serviceContext)) using (var server = new TestServer(TestApp.EchoApp, serviceContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
// "?" changes to "!" // "?" changes to "!"
await connection.SendEnd(sendString); await connection.Send(sendString);
await connection.ReceiveEnd( await connection.ReceiveEnd(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
"Connection: close", "Connection: close",
@ -60,11 +45,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
var serviceContext = new TestServiceContext(new AsyncConnectionFilter()); var serviceContext = new TestServiceContext(new AsyncConnectionFilter());
using (var server = new TestServer(App, serviceContext)) using (var server = new TestServer(TestApp.EchoApp, serviceContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"POST / HTTP/1.0", "POST / HTTP/1.0",
"Content-Length: 12", "Content-Length: 12",
"", "",
@ -84,7 +69,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
var serviceContext = new TestServiceContext(new ThrowingConnectionFilter()); var serviceContext = new TestServiceContext(new ThrowingConnectionFilter());
using (var server = new TestServer(App, serviceContext)) using (var server = new TestServer(TestApp.EchoApp, serviceContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
@ -108,15 +93,13 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
private class RewritingConnectionFilter : IConnectionFilter private class RewritingConnectionFilter : IConnectionFilter
{ {
private static Task _empty = Task.FromResult<object>(null);
private RewritingStream _rewritingStream; private RewritingStream _rewritingStream;
public Task OnConnectionAsync(ConnectionFilterContext context) public Task OnConnectionAsync(ConnectionFilterContext context)
{ {
_rewritingStream = new RewritingStream(context.Connection); _rewritingStream = new RewritingStream(context.Connection);
context.Connection = _rewritingStream; context.Connection = _rewritingStream;
return _empty; return TaskCache.CompletedTask;
} }
public int BytesRead => _rewritingStream.BytesRead; public int BytesRead => _rewritingStream.BytesRead;
@ -189,6 +172,15 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
return actual; return actual;
} }
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
var actual = await _innerStream.ReadAsync(buffer, offset, count);
BytesRead += actual;
return actual;
}
public override long Seek(long offset, SeekOrigin origin) public override long Seek(long offset, SeekOrigin origin)
{ {
return _innerStream.Seek(offset, origin); return _innerStream.Seek(offset, origin);
@ -211,6 +203,19 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
_innerStream.Write(buffer, offset, count); _innerStream.Write(buffer, offset, count);
} }
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
for (int i = 0; i < buffer.Length; i++)
{
if (buffer[i] == '?')
{
buffer[i] = (byte)'!';
}
}
return _innerStream.WriteAsync(buffer, offset, count, cancellationToken);
}
} }
} }
} }

View File

@ -22,7 +22,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
"GET / HTTP/1.0", "GET / HTTP/1.0",

View File

@ -42,39 +42,6 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
} }
} }
private async Task App(HttpContext httpContext)
{
var request = httpContext.Request;
var response = httpContext.Response;
while (true)
{
var buffer = new byte[8192];
var count = await request.Body.ReadAsync(buffer, 0, buffer.Length);
if (count == 0)
{
break;
}
await response.Body.WriteAsync(buffer, 0, count);
}
}
private async Task AppChunked(HttpContext httpContext)
{
var request = httpContext.Request;
var response = httpContext.Response;
var data = new MemoryStream();
await request.Body.CopyToAsync(data);
var bytes = data.ToArray();
response.Headers["Content-Length"] = bytes.Length.ToString();
await response.Body.WriteAsync(bytes, 0, bytes.Length);
}
private Task EmptyApp(HttpContext httpContext)
{
return Task.FromResult<object>(null);
}
[Theory] [Theory]
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public void EngineCanStartAndStop(TestServiceContext testContext) public void EngineCanStartAndStop(TestServiceContext testContext)
@ -88,7 +55,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public void ListenerCanCreateAndDispose(TestServiceContext testContext) public void ListenerCanCreateAndDispose(TestServiceContext testContext)
{ {
testContext.App = App; testContext.App = TestApp.EchoApp;
var engine = new KestrelEngine(testContext); var engine = new KestrelEngine(testContext);
engine.Start(1); engine.Start(1);
var address = ServerAddress.FromUrl("http://127.0.0.1:0/"); var address = ServerAddress.FromUrl("http://127.0.0.1:0/");
@ -101,22 +68,22 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public void ConnectionCanReadAndWrite(TestServiceContext testContext) public void ConnectionCanReadAndWrite(TestServiceContext testContext)
{ {
testContext.App = App; testContext.App = TestApp.EchoApp;
var engine = new KestrelEngine(testContext); var engine = new KestrelEngine(testContext);
engine.Start(1); engine.Start(1);
var address = ServerAddress.FromUrl("http://127.0.0.1:0/"); var address = ServerAddress.FromUrl("http://127.0.0.1:0/");
var started = engine.CreateServer(address); var started = engine.CreateServer(address);
var socket = TestConnection.CreateConnectedLoopbackSocket(address.Port); var socket = TestConnection.CreateConnectedLoopbackSocket(address.Port);
socket.Send(Encoding.ASCII.GetBytes("POST / HTTP/1.0\r\nContent-Length: 11\r\n\r\nHello World")); var data = "Hello World";
socket.Shutdown(SocketShutdown.Send); socket.Send(Encoding.ASCII.GetBytes($"POST / HTTP/1.0\r\nContent-Length: 11\r\n\r\n{data}"));
var buffer = new byte[8192]; var buffer = new byte[data.Length];
while (true) var read = 0;
while (read < data.Length)
{ {
var length = socket.Receive(buffer); read += socket.Receive(buffer, read, buffer.Length - read, SocketFlags.None);
if (length == 0) { break; }
var text = Encoding.ASCII.GetString(buffer, 0, length);
} }
socket.Dispose();
started.Dispose(); started.Dispose();
engine.Dispose(); engine.Dispose();
} }
@ -125,11 +92,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task Http10RequestReceivesHttp11Response(TestServiceContext testContext) public async Task Http10RequestReceivesHttp11Response(TestServiceContext testContext)
{ {
using (var server = new TestServer(App, testContext)) using (var server = new TestServer(TestApp.EchoApp, testContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"POST / HTTP/1.0", "POST / HTTP/1.0",
"Content-Length: 11", "Content-Length: 11",
"", "",
@ -148,11 +115,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task Http11(TestServiceContext testContext) public async Task Http11(TestServiceContext testContext)
{ {
using (var server = new TestServer(AppChunked, testContext)) using (var server = new TestServer(TestApp.EchoAppChunked, testContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
"GET / HTTP/1.1", "GET / HTTP/1.1",
@ -243,7 +210,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
Enumerable.Repeat(response, loopCount) Enumerable.Repeat(response, loopCount)
.Concat(new[] { lastResponse }); .Concat(new[] { lastResponse });
await connection.SendEnd(requestData.ToArray()); await connection.Send(requestData.ToArray());
await connection.ReceiveEnd(responseData.ToArray()); await connection.ReceiveEnd(responseData.ToArray());
} }
@ -258,11 +225,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task Http10ContentLength(TestServiceContext testContext) public async Task Http10ContentLength(TestServiceContext testContext)
{ {
using (var server = new TestServer(App, testContext)) using (var server = new TestServer(TestApp.EchoApp, testContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"POST / HTTP/1.0", "POST / HTTP/1.0",
"Content-Length: 11", "Content-Length: 11",
"", "",
@ -281,11 +248,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task Http10KeepAlive(TestServiceContext testContext) public async Task Http10KeepAlive(TestServiceContext testContext)
{ {
using (var server = new TestServer(AppChunked, testContext)) using (var server = new TestServer(TestApp.EchoAppChunked, testContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.0", "GET / HTTP/1.0",
"Connection: keep-alive", "Connection: keep-alive",
"", "",
@ -314,11 +281,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task Http10KeepAliveNotUsedIfResponseContentLengthNotSet(TestServiceContext testContext) public async Task Http10KeepAliveNotUsedIfResponseContentLengthNotSet(TestServiceContext testContext)
{ {
using (var server = new TestServer(App, testContext)) using (var server = new TestServer(TestApp.EchoApp, testContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.0", "GET / HTTP/1.0",
"Connection: keep-alive", "Connection: keep-alive",
"", "",
@ -347,11 +314,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task Http10KeepAliveContentLength(TestServiceContext testContext) public async Task Http10KeepAliveContentLength(TestServiceContext testContext)
{ {
using (var server = new TestServer(AppChunked, testContext)) using (var server = new TestServer(TestApp.EchoAppChunked, testContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"POST / HTTP/1.0", "POST / HTTP/1.0",
"Content-Length: 11", "Content-Length: 11",
"Connection: keep-alive", "Connection: keep-alive",
@ -382,7 +349,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task Expect100ContinueForBody(TestServiceContext testContext) public async Task Expect100ContinueForBody(TestServiceContext testContext)
{ {
using (var server = new TestServer(AppChunked, testContext)) using (var server = new TestServer(TestApp.EchoAppChunked, testContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
@ -392,8 +359,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
"Connection: close", "Connection: close",
"Content-Length: 11", "Content-Length: 11",
"\r\n"); "\r\n");
await connection.Receive("HTTP/1.1 100 Continue", "\r\n"); await connection.Receive(
await connection.SendEnd("Hello World"); "HTTP/1.1 100 Continue",
"",
"");
await connection.Send("Hello World");
await connection.Receive( await connection.Receive(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
"Connection: close", "Connection: close",
@ -409,7 +379,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task DisconnectingClient(TestServiceContext testContext) public async Task DisconnectingClient(TestServiceContext testContext)
{ {
using (var server = new TestServer(App, testContext)) using (var server = new TestServer(TestApp.EchoApp, testContext))
{ {
var socket = TestConnection.CreateConnectedLoopbackSocket(server.Port); var socket = TestConnection.CreateConnectedLoopbackSocket(server.Port);
await Task.Delay(200); await Task.Delay(200);
@ -418,15 +388,17 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
await Task.Delay(200); await Task.Delay(200);
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.0", "GET / HTTP/1.0",
"\r\n"); "",
"");
await connection.ReceiveEnd( await connection.ReceiveEnd(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
"Connection: close", "Connection: close",
$"Date: {testContext.DateHeaderValue}", $"Date: {testContext.DateHeaderValue}",
"Content-Length: 0", "Content-Length: 0",
"\r\n"); "",
"");
} }
} }
} }
@ -435,11 +407,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task ZeroContentLengthSetAutomaticallyAfterNoWrites(TestServiceContext testContext) public async Task ZeroContentLengthSetAutomaticallyAfterNoWrites(TestServiceContext testContext)
{ {
using (var server = new TestServer(EmptyApp, testContext)) using (var server = new TestServer(TestApp.EmptyApp, testContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
"GET / HTTP/1.0", "GET / HTTP/1.0",
@ -472,7 +444,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"Connection: close", "Connection: close",
"", "",
@ -488,7 +460,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.0", "GET / HTTP/1.0",
"", "",
""); "");
@ -507,11 +479,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task ZeroContentLengthNotSetAutomaticallyForHeadRequests(TestServiceContext testContext) public async Task ZeroContentLengthNotSetAutomaticallyForHeadRequests(TestServiceContext testContext)
{ {
using (var server = new TestServer(EmptyApp, testContext)) using (var server = new TestServer(TestApp.EmptyApp, testContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"HEAD / HTTP/1.1", "HEAD / HTTP/1.1",
"", "",
""); "");
@ -542,7 +514,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"POST / HTTP/1.1", "POST / HTTP/1.1",
"Content-Length: 3", "Content-Length: 3",
"", "",
@ -639,7 +611,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
await connection.ReceiveEnd( await connection.ReceiveForcedEnd(
"HTTP/1.1 101 Switching Protocols", "HTTP/1.1 101 Switching Protocols",
"Connection: Upgrade", "Connection: Upgrade",
$"Date: {testContext.DateHeaderValue}", $"Date: {testContext.DateHeaderValue}",
@ -654,7 +626,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
"Connection: keep-alive", "Connection: keep-alive",
"", "",
""); "");
await connection.ReceiveEnd( await connection.ReceiveForcedEnd(
"HTTP/1.1 101 Switching Protocols", "HTTP/1.1 101 Switching Protocols",
"Connection: Upgrade", "Connection: Upgrade",
$"Date: {testContext.DateHeaderValue}", $"Date: {testContext.DateHeaderValue}",
@ -679,7 +651,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
response.OnStarting(_ => response.OnStarting(_ =>
{ {
onStartingCalled = true; onStartingCalled = true;
return Task.FromResult<object>(null); return TaskCache.CompletedTask;
}, null); }, null);
// Anything added to the ResponseHeaders dictionary is ignored // Anything added to the ResponseHeaders dictionary is ignored
@ -689,25 +661,20 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
"GET / HTTP/1.1", "GET / HTTP/1.1",
"Connection: close", "Connection: close",
"", "",
""); "");
await connection.Receive( await connection.ReceiveForcedEnd(
"HTTP/1.1 500 Internal Server Error", "HTTP/1.1 500 Internal Server Error",
"");
await connection.Receive(
$"Date: {testContext.DateHeaderValue}", $"Date: {testContext.DateHeaderValue}",
"Content-Length: 0", "Content-Length: 0",
"", "",
"HTTP/1.1 500 Internal Server Error", "HTTP/1.1 500 Internal Server Error",
""); "Connection: close",
await connection.Receive("Connection: close",
"");
await connection.ReceiveEnd(
$"Date: {testContext.DateHeaderValue}", $"Date: {testContext.DateHeaderValue}",
"Content-Length: 0", "Content-Length: 0",
"", "",
@ -803,47 +770,18 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
Assert.Equal(1, testLogger.ApplicationErrorsLogged); Assert.Equal(1, testLogger.ApplicationErrorsLogged);
} }
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task ConnectionClosesWhenFinReceived(TestServiceContext testContext)
{
using (var server = new TestServer(AppChunked, testContext))
{
using (var connection = server.CreateConnection())
{
await connection.SendEnd(
"GET / HTTP/1.1",
"",
"Post / HTTP/1.1",
"Content-Length: 7",
"",
"Goodbye");
await connection.ReceiveEnd(
"HTTP/1.1 200 OK",
$"Date: {testContext.DateHeaderValue}",
"Content-Length: 0",
"",
"HTTP/1.1 200 OK",
$"Date: {testContext.DateHeaderValue}",
"Content-Length: 7",
"",
"Goodbye");
}
}
}
[Theory]
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task ConnectionClosesWhenFinReceivedBeforeRequestCompletes(TestServiceContext testContext) public async Task ConnectionClosesWhenFinReceivedBeforeRequestCompletes(TestServiceContext testContext)
{ {
using (var server = new TestServer(AppChunked, testContext)) using (var server = new TestServer(TestApp.EchoAppChunked, testContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
"POST / HTTP/1.1"); "POST / HTTP/1.1");
connection.Shutdown(SocketShutdown.Send);
await connection.ReceiveForcedEnd( await connection.ReceiveForcedEnd(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
$"Date: {testContext.DateHeaderValue}", $"Date: {testContext.DateHeaderValue}",
@ -859,11 +797,12 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
"POST / HTTP/1.1", "POST / HTTP/1.1",
"Content-Length: 7"); "Content-Length: 7");
connection.Shutdown(SocketShutdown.Send);
await connection.ReceiveForcedEnd( await connection.ReceiveForcedEnd(
"HTTP/1.1 200 OK", "HTTP/1.1 200 OK",
$"Date: {testContext.DateHeaderValue}", $"Date: {testContext.DateHeaderValue}",
@ -918,24 +857,18 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
"GET / HTTP/1.1", "GET / HTTP/1.1",
"Connection: close",
"", "",
""); "");
await connection.Receive( await connection.ReceiveEnd(
"HTTP/1.1 500 Internal Server Error", "HTTP/1.1 500 Internal Server Error",
"");
await connection.Receive(
$"Date: {testContext.DateHeaderValue}", $"Date: {testContext.DateHeaderValue}",
"Content-Length: 0", "Content-Length: 0",
"", "",
"HTTP/1.1 500 Internal Server Error", "HTTP/1.1 500 Internal Server Error",
"Connection: close",
"");
await connection.ReceiveEnd(
$"Date: {testContext.DateHeaderValue}", $"Date: {testContext.DateHeaderValue}",
"Content-Length: 0", "Content-Length: 0",
"", "",
@ -1072,7 +1005,6 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
Assert.Equal(2, abortedRequestId); Assert.Equal(2, abortedRequestId);
} }
[Theory]
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task FailedWritesResultInAbortedRequest(TestServiceContext testContext) public async Task FailedWritesResultInAbortedRequest(TestServiceContext testContext)
{ {
@ -1215,7 +1147,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
"GET / HTTP/1.1", "GET / HTTP/1.1",
@ -1262,7 +1194,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
"GET / HTTP/1.1", "GET / HTTP/1.1",
@ -1290,13 +1222,13 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
using (var server = new TestServer(async httpContext => using (var server = new TestServer(async httpContext =>
{ {
var path = httpContext.Request.Path.Value; var path = httpContext.Request.Path.Value;
httpContext.Response.Headers["Content-Length"] = new[] {path.Length.ToString() }; httpContext.Response.Headers["Content-Length"] = new[] { path.Length.ToString() };
await httpContext.Response.WriteAsync(path); await httpContext.Response.WriteAsync(path);
})) }))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
$"GET {inputPath} HTTP/1.1", $"GET {inputPath} HTTP/1.1",
"", "",
""); "");
@ -1337,7 +1269,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
@ -1348,11 +1280,13 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
"", "",
"hello, world"); "hello, world");
}
}
Assert.Equal(1, callOrder.Pop()); Assert.Equal(1, callOrder.Pop());
Assert.Equal(2, callOrder.Pop()); Assert.Equal(2, callOrder.Pop());
} }
}
}
[Theory] [Theory]
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
@ -1381,7 +1315,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
@ -1391,47 +1325,47 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
$"Content-Length: {response.Length}", $"Content-Length: {response.Length}",
"", "",
"hello, world"); "hello, world");
}
}
Assert.Equal(1, callOrder.Pop()); Assert.Equal(1, callOrder.Pop());
Assert.Equal(2, callOrder.Pop()); Assert.Equal(2, callOrder.Pop());
} }
}
}
[Theory] [Theory]
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task UpgradeRequestIsNotKeptAliveOrChunked(TestServiceContext testContext) public async Task UpgradeRequestIsNotKeptAliveOrChunked(TestServiceContext testContext)
{ {
const string message = "Hello World";
using (var server = new TestServer(async context => using (var server = new TestServer(async context =>
{ {
var upgradeFeature = context.Features.Get<IHttpUpgradeFeature>(); var upgradeFeature = context.Features.Get<IHttpUpgradeFeature>();
var duplexStream = await upgradeFeature.UpgradeAsync(); var duplexStream = await upgradeFeature.UpgradeAsync();
while (true) var buffer = new byte[message.Length];
var read = 0;
while (read < message.Length)
{ {
var buffer = new byte[8192]; read += await duplexStream.ReadAsync(buffer, read, buffer.Length - read).TimeoutAfter(TimeSpan.FromSeconds(10));
var count = await duplexStream.ReadAsync(buffer, 0, buffer.Length);
if (count == 0)
{
break;
}
await duplexStream.WriteAsync(buffer, 0, count);
} }
await duplexStream.WriteAsync(buffer, 0, read);
}, testContext)) }, testContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET / HTTP/1.1", "GET / HTTP/1.1",
"Connection: Upgrade", "Connection: Upgrade",
"", "",
"Hello World"); message);
await connection.ReceiveEnd( await connection.ReceiveForcedEnd(
"HTTP/1.1 101 Switching Protocols", "HTTP/1.1 101 Switching Protocols",
"Connection: Upgrade", "Connection: Upgrade",
$"Date: {testContext.DateHeaderValue}", $"Date: {testContext.DateHeaderValue}",
"", "",
"Hello World"); message);
} }
} }
} }

View File

@ -6,6 +6,7 @@ using System.IO;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Hosting.Server;
using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel; using Microsoft.AspNetCore.Server.Kestrel;
using Microsoft.AspNetCore.Server.Kestrel.Internal; using Microsoft.AspNetCore.Server.Kestrel.Internal;
@ -22,10 +23,23 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
private readonly SocketInput _socketInput; private readonly SocketInput _socketInput;
private readonly MemoryPool _pool; private readonly MemoryPool _pool;
private readonly Frame<object> _frame; private readonly TestFrame<object> _frame;
private readonly ServiceContext _serviceContext; private readonly ServiceContext _serviceContext;
private readonly ConnectionContext _connectionContext; private readonly ConnectionContext _connectionContext;
private class TestFrame<TContext> : Frame<TContext>
{
public TestFrame(IHttpApplication<TContext> application, ConnectionContext context)
: base(application, context)
{
}
public Task ProduceEndAsync()
{
return ProduceEnd();
}
}
public FrameTests() public FrameTests()
{ {
var trace = new KestrelTrace(new TestKestrelTrace()); var trace = new KestrelTrace(new TestKestrelTrace());
@ -50,7 +64,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
ConnectionControl = Mock.Of<IConnectionControl>() ConnectionControl = Mock.Of<IConnectionControl>()
}; };
_frame = new Frame<object>(application: null, context: _connectionContext); _frame = new TestFrame<object>(application: null, context: _connectionContext);
_frame.Reset(); _frame.Reset();
_frame.InitializeHeaders(); _frame.InitializeHeaders();
} }
@ -713,5 +727,73 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
await requestProcessingTask.TimeoutAfter(TimeSpan.FromSeconds(10)); await requestProcessingTask.TimeoutAfter(TimeSpan.FromSeconds(10));
_socketInput.IncomingFin(); _socketInput.IncomingFin();
} }
[Fact]
public void RequestAbortedTokenIsResetBeforeLastWriteWithContentLength()
{
_frame.ResponseHeaders["Content-Length"] = "12";
// Need to compare WaitHandle ref since CancellationToken is struct
var original = _frame.RequestAborted.WaitHandle;
foreach (var ch in "hello, worl")
{
_frame.Write(new ArraySegment<byte>(new[] { (byte)ch }));
Assert.Same(original, _frame.RequestAborted.WaitHandle);
}
_frame.Write(new ArraySegment<byte>(new[] { (byte)'d' }));
Assert.NotSame(original, _frame.RequestAborted.WaitHandle);
}
[Fact]
public async Task RequestAbortedTokenIsResetBeforeLastWriteAsyncWithContentLength()
{
_frame.ResponseHeaders["Content-Length"] = "12";
// Need to compare WaitHandle ref since CancellationToken is struct
var original = _frame.RequestAborted.WaitHandle;
foreach (var ch in "hello, worl")
{
await _frame.WriteAsync(new ArraySegment<byte>(new[] { (byte)ch }), default(CancellationToken));
Assert.Same(original, _frame.RequestAborted.WaitHandle);
}
await _frame.WriteAsync(new ArraySegment<byte>(new[] { (byte)'d' }), default(CancellationToken));
Assert.NotSame(original, _frame.RequestAborted.WaitHandle);
}
[Fact]
public async Task RequestAbortedTokenIsResetBeforeLastWriteAsyncAwaitedWithContentLength()
{
_frame.ResponseHeaders["Content-Length"] = "12";
// Need to compare WaitHandle ref since CancellationToken is struct
var original = _frame.RequestAborted.WaitHandle;
foreach (var ch in "hello, worl")
{
await _frame.WriteAsyncAwaited(new ArraySegment<byte>(new[] { (byte)ch }), default(CancellationToken));
Assert.Same(original, _frame.RequestAborted.WaitHandle);
}
await _frame.WriteAsyncAwaited(new ArraySegment<byte>(new[] { (byte)'d' }), default(CancellationToken));
Assert.NotSame(original, _frame.RequestAborted.WaitHandle);
}
[Fact]
public async Task RequestAbortedTokenIsResetBeforeLastWriteWithChunkedEncoding()
{
// Need to compare WaitHandle ref since CancellationToken is struct
var original = _frame.RequestAborted.WaitHandle;
_frame.HttpVersion = "HTTP/1.1";
await _frame.WriteAsync(new ArraySegment<byte>(Encoding.ASCII.GetBytes("hello, world")), default(CancellationToken));
Assert.Same(original, _frame.RequestAborted.WaitHandle);
await _frame.ProduceEndAsync();
Assert.NotSame(original, _frame.RequestAborted.WaitHandle);
}
} }
} }

View File

@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
"GET /%41%CC%8A/A/../B/%41%CC%8A HTTP/1.1", "GET /%41%CC%8A/A/../B/%41%CC%8A HTTP/1.1",
"", "",
""); "");
@ -71,7 +71,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
$"GET {requestTarget} HTTP/1.1", $"GET {requestTarget} HTTP/1.1",
"", "",
""); "");
@ -115,7 +115,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.SendEnd( await connection.Send(
$"GET {requestTarget} HTTP/1.1", $"GET {requestTarget} HTTP/1.1",
"", "",
""); "");

View File

@ -12,20 +12,27 @@ namespace Microsoft.AspNetCore.Testing
public class DummyApplication : IHttpApplication<HttpContext> public class DummyApplication : IHttpApplication<HttpContext>
{ {
private readonly RequestDelegate _requestDelegate; private readonly RequestDelegate _requestDelegate;
private readonly IHttpContextFactory _httpContextFactory;
public DummyApplication(RequestDelegate requestDelegate) public DummyApplication(RequestDelegate requestDelegate)
: this(requestDelegate, null)
{
}
public DummyApplication(RequestDelegate requestDelegate, IHttpContextFactory httpContextFactory)
{ {
_requestDelegate = requestDelegate; _requestDelegate = requestDelegate;
_httpContextFactory = httpContextFactory;
} }
public HttpContext CreateContext(IFeatureCollection contextFeatures) public HttpContext CreateContext(IFeatureCollection contextFeatures)
{ {
return new DefaultHttpContext(contextFeatures); return _httpContextFactory?.Create(contextFeatures) ?? new DefaultHttpContext(contextFeatures);
} }
public void DisposeContext(HttpContext context, Exception exception) public void DisposeContext(HttpContext context, Exception exception)
{ {
_httpContextFactory?.Dispose(context);
} }
public async Task ProcessRequestAsync(HttpContext context) public async Task ProcessRequestAsync(HttpContext context)

49
test/shared/TestApp.cs Normal file
View File

@ -0,0 +1,49 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.Testing
{
public static class TestApp
{
public static async Task EchoApp(HttpContext httpContext)
{
var request = httpContext.Request;
var response = httpContext.Response;
var buffer = new byte[httpContext.Request.ContentLength ?? 0];
var bytesRead = 0;
while (bytesRead < buffer.Length)
{
var count = await request.Body.ReadAsync(buffer, bytesRead, buffer.Length - bytesRead);
bytesRead += count;
}
if (buffer.Length > 0)
{
await response.Body.WriteAsync(buffer, 0, buffer.Length);
}
}
public static async Task EchoAppChunked(HttpContext httpContext)
{
var request = httpContext.Request;
var response = httpContext.Response;
var data = new MemoryStream();
await request.Body.CopyToAsync(data);
var bytes = data.ToArray();
response.Headers["Content-Length"] = bytes.Length.ToString();
await response.Body.WriteAsync(bytes, 0, bytes.Length);
}
public static Task EmptyApp(HttpContext httpContext)
{
return TaskCache.CompletedTask;
}
}
}

View File

@ -50,22 +50,6 @@ namespace Microsoft.AspNetCore.Testing
_stream.Flush(); _stream.Flush();
} }
public async Task SendAllTryEnd(params string[] lines)
{
await SendAll(lines);
try
{
_socket.Shutdown(SocketShutdown.Send);
}
catch (IOException)
{
// The server may forcefully close the connection (usually due to a bad request),
// so an IOException: "An existing connection was forcibly closed by the remote host"
// isn't guaranteed but not unexpected.
}
}
public async Task Send(params string[] lines) public async Task Send(params string[] lines)
{ {
var text = string.Join("\r\n", lines); var text = string.Join("\r\n", lines);
@ -82,12 +66,6 @@ namespace Microsoft.AspNetCore.Testing
_stream.Flush(); _stream.Flush();
} }
public async Task SendEnd(params string[] lines)
{
await Send(lines);
_socket.Shutdown(SocketShutdown.Send);
}
public async Task Receive(params string[] lines) public async Task Receive(params string[] lines)
{ {
var expected = string.Join("\r\n", lines); var expected = string.Join("\r\n", lines);
@ -95,6 +73,7 @@ namespace Microsoft.AspNetCore.Testing
var offset = 0; var offset = 0;
while (offset < expected.Length) while (offset < expected.Length)
{ {
var data = new byte[expected.Length];
var task = _reader.ReadAsync(actual, offset, actual.Length - offset); var task = _reader.ReadAsync(actual, offset, actual.Length - offset);
if (!Debugger.IsAttached) if (!Debugger.IsAttached)
{ {
@ -108,12 +87,13 @@ namespace Microsoft.AspNetCore.Testing
offset += count; offset += count;
} }
Assert.Equal(expected, new String(actual, 0, offset)); Assert.Equal(expected, new string(actual, 0, offset));
} }
public async Task ReceiveEnd(params string[] lines) public async Task ReceiveEnd(params string[] lines)
{ {
await Receive(lines); await Receive(lines);
_socket.Shutdown(SocketShutdown.Send);
var ch = new char[128]; var ch = new char[128];
var count = await _reader.ReadAsync(ch, 0, 128).TimeoutAfter(TimeSpan.FromMinutes(1)); var count = await _reader.ReadAsync(ch, 0, 128).TimeoutAfter(TimeSpan.FromMinutes(1));
var text = new string(ch, 0, count); var text = new string(ch, 0, count);
@ -139,11 +119,16 @@ namespace Microsoft.AspNetCore.Testing
} }
} }
public void Shutdown(SocketShutdown how)
{
_socket.Shutdown(how);
}
public Task WaitForConnectionClose() public Task WaitForConnectionClose()
{ {
var tcs = new TaskCompletionSource<object>(); var tcs = new TaskCompletionSource<object>();
var eventArgs = new SocketAsyncEventArgs(); var eventArgs = new SocketAsyncEventArgs();
eventArgs.SetBuffer(new byte[1], 0, 1); eventArgs.SetBuffer(new byte[128], 0, 128);
eventArgs.Completed += ReceiveAsyncCompleted; eventArgs.Completed += ReceiveAsyncCompleted;
eventArgs.UserToken = tcs; eventArgs.UserToken = tcs;
@ -157,11 +142,16 @@ namespace Microsoft.AspNetCore.Testing
private void ReceiveAsyncCompleted(object sender, SocketAsyncEventArgs e) private void ReceiveAsyncCompleted(object sender, SocketAsyncEventArgs e)
{ {
var tcs = (TaskCompletionSource<object>)e.UserToken;
if (e.BytesTransferred == 0) if (e.BytesTransferred == 0)
{ {
var tcs = (TaskCompletionSource<object>)e.UserToken;
tcs.SetResult(null); tcs.SetResult(null);
} }
else
{
tcs.SetException(new IOException(
$"Expected connection close, received data instead: \"{_reader.CurrentEncoding.GetString(e.Buffer, 0, e.BytesTransferred)}\""));
}
} }
public static Socket CreateConnectedLoopbackSocket(int port) public static Socket CreateConnectedLoopbackSocket(int port)

View File

@ -29,12 +29,17 @@ namespace Microsoft.AspNetCore.Testing
} }
public TestServer(RequestDelegate app, TestServiceContext context, string serverAddress) public TestServer(RequestDelegate app, TestServiceContext context, string serverAddress)
: this(app, context, serverAddress, null)
{
}
public TestServer(RequestDelegate app, TestServiceContext context, string serverAddress, IHttpContextFactory httpContextFactory)
{ {
Context = context; Context = context;
context.FrameFactory = connectionContext => context.FrameFactory = connectionContext =>
{ {
return new Frame<HttpContext>(new DummyApplication(app), connectionContext); return new Frame<HttpContext>(new DummyApplication(app, httpContextFactory), connectionContext);
}; };
try try