Call OnStarting before verifying response length (#1303).

- Also don't close connection when Content-Length set but no bytes written.
This commit is contained in:
Cesar Blum Silveira 2017-01-25 15:07:48 -08:00
parent e2a2e9a620
commit c11aedd272
3 changed files with 252 additions and 41 deletions

View File

@ -508,23 +508,29 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
public void Flush() public void Flush()
{ {
ProduceStartAndFireOnStarting().GetAwaiter().GetResult(); InitializeResponse(0).GetAwaiter().GetResult();
SocketOutput.Flush(); SocketOutput.Flush();
} }
public async Task FlushAsync(CancellationToken cancellationToken) public async Task FlushAsync(CancellationToken cancellationToken)
{ {
await ProduceStartAndFireOnStarting(); await InitializeResponse(0);
await SocketOutput.FlushAsync(cancellationToken); await SocketOutput.FlushAsync(cancellationToken);
} }
public void Write(ArraySegment<byte> data) public void Write(ArraySegment<byte> data)
{ {
// For the first write, ensure headers are flushed if Write(Chunked)isn't called. // For the first write, ensure headers are flushed if Write(Chunked) isn't called.
var firstWrite = !HasResponseStarted; var firstWrite = !HasResponseStarted;
VerifyAndUpdateWrite(data.Count); if (firstWrite)
ProduceStartAndFireOnStarting().GetAwaiter().GetResult(); {
InitializeResponse(data.Count).GetAwaiter().GetResult();
}
else
{
VerifyAndUpdateWrite(data.Count);
}
if (_canHaveBody) if (_canHaveBody)
{ {
@ -589,9 +595,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
public async Task WriteAsyncAwaited(ArraySegment<byte> data, CancellationToken cancellationToken) public async Task WriteAsyncAwaited(ArraySegment<byte> data, CancellationToken cancellationToken)
{ {
VerifyAndUpdateWrite(data.Count); await InitializeResponseAwaited(data.Count);
await ProduceStartAndFireOnStarting();
// WriteAsyncAwaited is only called for the first write to the body. // WriteAsyncAwaited is only called for the first write to the body.
// Ensure headers are flushed if Write(Chunked)Async isn't called. // Ensure headers are flushed if Write(Chunked)Async isn't called.
@ -645,7 +649,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
responseHeaders.HeaderContentLengthValue.HasValue && responseHeaders.HeaderContentLengthValue.HasValue &&
_responseBytesWritten < responseHeaders.HeaderContentLengthValue.Value) _responseBytesWritten < responseHeaders.HeaderContentLengthValue.Value)
{ {
_keepAlive = false; // We need to close the connection if any bytes were written since the client
// cannot be certain of how many bytes it will receive.
if (_responseBytesWritten > 0)
{
_keepAlive = false;
}
ReportApplicationError(new InvalidOperationException( ReportApplicationError(new InvalidOperationException(
$"Response Content-Length mismatch: too few bytes written ({_responseBytesWritten} of {responseHeaders.HeaderContentLengthValue.Value}).")); $"Response Content-Length mismatch: too few bytes written ({_responseBytesWritten} of {responseHeaders.HeaderContentLengthValue.Value})."));
} }
@ -688,7 +698,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
} }
} }
public Task ProduceStartAndFireOnStarting() public Task InitializeResponse(int firstWriteByteCount)
{ {
if (HasResponseStarted) if (HasResponseStarted)
{ {
@ -697,7 +707,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
if (_onStarting != null) if (_onStarting != null)
{ {
return ProduceStartAndFireOnStartingAwaited(); return InitializeResponseAwaited(firstWriteByteCount);
} }
if (_applicationException != null) if (_applicationException != null)
@ -705,11 +715,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
ThrowResponseAbortedException(); ThrowResponseAbortedException();
} }
VerifyAndUpdateWrite(firstWriteByteCount);
ProduceStart(appCompleted: false); ProduceStart(appCompleted: false);
return TaskCache.CompletedTask; return TaskCache.CompletedTask;
} }
private async Task ProduceStartAndFireOnStartingAwaited() private async Task InitializeResponseAwaited(int firstWriteByteCount)
{ {
await FireOnStarting(); await FireOnStarting();
@ -718,6 +730,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
ThrowResponseAbortedException(); ThrowResponseAbortedException();
} }
VerifyAndUpdateWrite(firstWriteByteCount);
ProduceStart(appCompleted: false); ProduceStart(appCompleted: false);
} }

View File

@ -480,7 +480,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
} }
[Fact] [Fact]
public async Task WhenAppWritesMoreThanContentLengthWriteThrowsAndConnectionCloses() public async Task ThrowsAndClosesConnectionWhenAppWritesMoreThanContentLengthWrite()
{ {
var testLogger = new TestApplicationErrorLogger(); var testLogger = new TestApplicationErrorLogger();
var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) };
@ -515,7 +515,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
} }
[Fact] [Fact]
public async Task WhenAppWritesMoreThanContentLengthWriteAsyncThrowsAndConnectionCloses() public async Task ThrowsAndClosesConnectionWhenAppWritesMoreThanContentLengthWriteAsync()
{ {
var testLogger = new TestApplicationErrorLogger(); var testLogger = new TestApplicationErrorLogger();
var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) };
@ -549,15 +549,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
} }
[Fact] [Fact]
public async Task WhenAppWritesMoreThanContentLengthAndResponseNotStarted500ResponseSentAndConnectionCloses() public async Task InternalServerErrorAndConnectionClosedOnWriteWithMoreThanContentLengthAndResponseNotStarted()
{ {
var testLogger = new TestApplicationErrorLogger(); var testLogger = new TestApplicationErrorLogger();
var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) };
using (var server = new TestServer(async httpContext => using (var server = new TestServer(httpContext =>
{ {
var response = Encoding.ASCII.GetBytes("hello, world");
httpContext.Response.ContentLength = 5; httpContext.Response.ContentLength = 5;
await httpContext.Response.WriteAsync("hello, world"); httpContext.Response.Body.Write(response, 0, response.Length);
return TaskCache.CompletedTask;
}, serviceContext)) }, serviceContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
@ -566,7 +568,42 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
await connection.ReceiveEnd( await connection.ReceiveForcedEnd(
$"HTTP/1.1 500 Internal Server Error",
"Connection: close",
$"Date: {server.Context.DateHeaderValue}",
"Content-Length: 0",
"",
"");
}
}
var logMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error);
Assert.Equal(
$"Response Content-Length mismatch: too many bytes written (12 of 5).",
logMessage.Exception.Message);
}
[Fact]
public async Task InternalServerErrorAndConnectionClosedOnWriteAsyncWithMoreThanContentLengthAndResponseNotStarted()
{
var testLogger = new TestApplicationErrorLogger();
var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) };
using (var server = new TestServer(httpContext =>
{
var response = Encoding.ASCII.GetBytes("hello, world");
httpContext.Response.ContentLength = 5;
return httpContext.Response.Body.WriteAsync(response, 0, response.Length);
}, serviceContext))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"",
"");
await connection.ReceiveForcedEnd(
$"HTTP/1.1 500 Internal Server Error", $"HTTP/1.1 500 Internal Server Error",
"Connection: close", "Connection: close",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
@ -616,7 +653,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
} }
[Fact] [Fact]
public async Task WhenAppSetsContentLengthButDoesNotWriteBody500ResponseSentAndConnectionCloses() public async Task WhenAppSetsContentLengthButDoesNotWriteBody500ResponseSentAndConnectionDoesNotClose()
{ {
var testLogger = new TestApplicationErrorLogger(); var testLogger = new TestApplicationErrorLogger();
var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) };
@ -630,12 +667,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
{ {
await connection.Send( await connection.Send(
"GET / HTTP/1.1",
"",
"GET / HTTP/1.1", "GET / HTTP/1.1",
"", "",
""); "");
await connection.ReceiveEnd( await connection.Receive(
$"HTTP/1.1 500 Internal Server Error", "HTTP/1.1 500 Internal Server Error",
"Connection: close", $"Date: {server.Context.DateHeaderValue}",
"Content-Length: 0",
"",
"HTTP/1.1 500 Internal Server Error",
$"Date: {server.Context.DateHeaderValue}", $"Date: {server.Context.DateHeaderValue}",
"Content-Length: 0", "Content-Length: 0",
"", "",
@ -643,10 +685,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
} }
} }
var errorMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); var error = testLogger.Messages.Where(message => message.LogLevel == LogLevel.Error);
Assert.Equal( Assert.Equal(2, error.Count());
$"Response Content-Length mismatch: too few bytes written (0 of 5).", Assert.All(error, message => message.Equals("Response Content-Length mismatch: too few bytes written (0 of 5)."));
errorMessage.Exception.Message);
} }
[Theory] [Theory]
@ -1050,6 +1091,170 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
} }
} }
[Fact]
public async Task FirstWriteVerifiedAfterOnStarting()
{
using (var server = new TestServer(httpContext =>
{
httpContext.Response.OnStarting(() =>
{
// Change response to chunked
httpContext.Response.ContentLength = null;
return TaskCache.CompletedTask;
});
var response = Encoding.ASCII.GetBytes("hello, world");
httpContext.Response.ContentLength = response.Length - 1;
// If OnStarting is not run before verifying writes, an error response will be sent.
httpContext.Response.Body.Write(response, 0, response.Length);
return TaskCache.CompletedTask;
}, new TestServiceContext()))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"",
"");
await connection.Receive(
"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
$"Transfer-Encoding: chunked",
"",
"c",
"hello, world",
"0",
"",
"");
}
}
}
[Fact]
public async Task SubsequentWriteVerifiedAfterOnStarting()
{
using (var server = new TestServer(httpContext =>
{
httpContext.Response.OnStarting(() =>
{
// Change response to chunked
httpContext.Response.ContentLength = null;
return TaskCache.CompletedTask;
});
var response = Encoding.ASCII.GetBytes("hello, world");
httpContext.Response.ContentLength = response.Length - 1;
// If OnStarting is not run before verifying writes, an error response will be sent.
httpContext.Response.Body.Write(response, 0, response.Length / 2);
httpContext.Response.Body.Write(response, response.Length / 2, response.Length - response.Length / 2);
return TaskCache.CompletedTask;
}, new TestServiceContext()))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"",
"");
await connection.Receive(
"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
$"Transfer-Encoding: chunked",
"",
"6",
"hello,",
"6",
" world",
"0",
"",
"");
}
}
}
[Fact]
public async Task FirstWriteAsyncVerifiedAfterOnStarting()
{
using (var server = new TestServer(httpContext =>
{
httpContext.Response.OnStarting(() =>
{
// Change response to chunked
httpContext.Response.ContentLength = null;
return TaskCache.CompletedTask;
});
var response = Encoding.ASCII.GetBytes("hello, world");
httpContext.Response.ContentLength = response.Length - 1;
// If OnStarting is not run before verifying writes, an error response will be sent.
return httpContext.Response.Body.WriteAsync(response, 0, response.Length);
}, new TestServiceContext()))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"",
"");
await connection.Receive(
"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
$"Transfer-Encoding: chunked",
"",
"c",
"hello, world",
"0",
"",
"");
}
}
}
[Fact]
public async Task SubsequentWriteAsyncVerifiedAfterOnStarting()
{
using (var server = new TestServer(async httpContext =>
{
httpContext.Response.OnStarting(() =>
{
// Change response to chunked
httpContext.Response.ContentLength = null;
return TaskCache.CompletedTask;
});
var response = Encoding.ASCII.GetBytes("hello, world");
httpContext.Response.ContentLength = response.Length - 1;
// If OnStarting is not run before verifying writes, an error response will be sent.
await httpContext.Response.Body.WriteAsync(response, 0, response.Length / 2);
await httpContext.Response.Body.WriteAsync(response, response.Length / 2, response.Length - response.Length / 2);
}, new TestServiceContext()))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"",
"");
await connection.Receive(
"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
$"Transfer-Encoding: chunked",
"",
"6",
"hello,",
"6",
" world",
"0",
"",
"");
}
}
}
public static TheoryData<string, StringValues, string> NullHeaderData public static TheoryData<string, StringValues, string> NullHeaderData
{ {
get get

View File

@ -883,9 +883,8 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[MemberData(nameof(ConnectionFilterData))] [MemberData(nameof(ConnectionFilterData))]
public async Task ThrowingInOnStartingResultsInFailedWritesAnd500Response(TestServiceContext testContext) public async Task ThrowingInOnStartingResultsInFailedWritesAnd500Response(TestServiceContext testContext)
{ {
var onStartingCallCount1 = 0; var callback1Called = false;
var onStartingCallCount2 = 0; var callback2CallCount = 0;
var failedWriteCount = 0;
var testLogger = new TestApplicationErrorLogger(); var testLogger = new TestApplicationErrorLogger();
testContext.Log = new KestrelTrace(testLogger); testContext.Log = new KestrelTrace(testLogger);
@ -897,23 +896,17 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
var response = httpContext.Response; var response = httpContext.Response;
response.OnStarting(_ => response.OnStarting(_ =>
{ {
onStartingCallCount1++; callback1Called = true;
throw onStartingException; throw onStartingException;
}, null); }, null);
response.OnStarting(_ => response.OnStarting(_ =>
{ {
onStartingCallCount2++; callback2CallCount++;
throw onStartingException; throw onStartingException;
}, null); }, null);
response.Headers["Content-Length"] = new[] { "11" }; var writeException = await Assert.ThrowsAsync<ObjectDisposedException>(async () => await response.Body.FlushAsync());
var writeException = await Assert.ThrowsAsync<ObjectDisposedException>(async () =>
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11));
Assert.Same(onStartingException, writeException.InnerException); Assert.Same(onStartingException, writeException.InnerException);
failedWriteCount++;
}, testContext)) }, testContext))
{ {
using (var connection = server.CreateConnection()) using (var connection = server.CreateConnection())
@ -943,10 +936,10 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
} }
} }
// The first registered OnStarting callback should not be called, // The first registered OnStarting callback should not have been called,
// since they are called LIFO and the other one failed. // since they are called LIFO and the other one failed.
Assert.Equal(0, onStartingCallCount1); Assert.False(callback1Called);
Assert.Equal(2, onStartingCallCount2); Assert.Equal(2, callback2CallCount);
Assert.Equal(2, testLogger.ApplicationErrorsLogged); Assert.Equal(2, testLogger.ApplicationErrorsLogged);
} }