From f8813a600df30e018a6224048b31244be3d456b9 Mon Sep 17 00:00:00 2001 From: Cesar Blum Silveira Date: Fri, 7 Oct 2016 10:23:07 -0700 Subject: [PATCH] Handle response content length mismatches (#175). --- .../Internal/Http/Frame.cs | 27 +- .../Internal/Http/FrameHeaders.Generated.cs | 6 + .../Internal/Http/FrameHeaders.cs | 13 + .../Internal/Http/FrameOfT.cs | 10 + .../Internal/Http/FrameResponseHeaders.cs | 4 + .../Internal/Infrastructure/IKestrelTrace.cs | 2 +- .../Internal/Infrastructure/KestrelTrace.cs | 6 +- .../ResponseTests.cs | 286 +++++++++++++++++- .../EngineTests.cs | 71 ++--- .../FrameResponseHeadersTests.cs | 93 +++++- test/shared/TestApplicationErrorLogger.cs | 30 +- test/shared/TestKestrelTrace.cs | 10 - .../KnownHeaders.cs | 22 +- 13 files changed, 482 insertions(+), 98 deletions(-) diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs index f85940605c..8312e88bef 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Globalization; using System.IO; using System.Linq; using System.Net; @@ -75,7 +76,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http protected readonly long _keepAliveMilliseconds; private readonly long _requestHeadersTimeoutMilliseconds; - private int _responseBytesWritten; + protected long _responseBytesWritten; public Frame(ConnectionContext context) { @@ -516,8 +517,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public void Write(ArraySegment data) { + VerifyAndUpdateWrite(data.Count); ProduceStartAndFireOnStarting().GetAwaiter().GetResult(); - _responseBytesWritten += data.Count; if (_canHaveBody) { @@ -547,7 +548,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http return WriteAsyncAwaited(data, cancellationToken); } - _responseBytesWritten += data.Count; + VerifyAndUpdateWrite(data.Count); if (_canHaveBody) { @@ -573,8 +574,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public async Task WriteAsyncAwaited(ArraySegment data, CancellationToken cancellationToken) { + VerifyAndUpdateWrite(data.Count); + await ProduceStartAndFireOnStarting(); - _responseBytesWritten += data.Count; if (_canHaveBody) { @@ -598,6 +600,23 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } } + private void VerifyAndUpdateWrite(int count) + { + var responseHeaders = FrameResponseHeaders; + + if (responseHeaders != null && + !responseHeaders.HasTransferEncoding && + responseHeaders.HasContentLength && + _responseBytesWritten + count > responseHeaders.HeaderContentLengthValue.Value) + { + _keepAlive = false; + throw new InvalidOperationException( + $"Response Content-Length mismatch: too many bytes written ({_responseBytesWritten + count} of {responseHeaders.HeaderContentLengthValue.Value})."); + } + + _responseBytesWritten += count; + } + private void WriteChunked(ArraySegment data) { SocketOutput.Write(data, chunk: true); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.Generated.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.Generated.cs index d37c1c8cb6..88fa8d6ac7 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.Generated.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.Generated.cs @@ -3697,6 +3697,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { _bits = 0; _headers = default(HeaderReferences); + MaybeUnknown?.Clear(); } @@ -5670,6 +5671,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } set { + _contentLength = ParseContentLength(value); _bits |= 2048L; _headers._ContentLength = value; _headers._rawContentLength = null; @@ -7384,6 +7386,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { if ("Content-Length".Equals(key, StringComparison.OrdinalIgnoreCase)) { + _contentLength = ParseContentLength(value); _bits |= 2048L; _headers._ContentLength = value; _headers._rawContentLength = null; @@ -7809,6 +7812,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { ThrowDuplicateKeyException(); } + _contentLength = ParseContentLength(value); _bits |= 2048L; _headers._ContentLength = value; _headers._rawContentLength = null; @@ -8350,6 +8354,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { if (((_bits & 2048L) != 0)) { + _contentLength = null; _bits &= ~2048L; _headers._ContentLength = StringValues.Empty; _headers._rawContentLength = null; @@ -8601,6 +8606,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { _bits = 0; _headers = default(HeaderReferences); + _contentLength = null; MaybeUnknown?.Clear(); } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.cs index ac283a006c..c6b9d0c59d 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.cs @@ -4,6 +4,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Globalization; using System.Linq; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Primitives; @@ -232,6 +233,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } } + public static long ParseContentLength(StringValues value) + { + try + { + return long.Parse(value, NumberStyles.AllowLeadingWhite | NumberStyles.AllowTrailingWhite, CultureInfo.InvariantCulture); + } + catch (FormatException ex) + { + throw new InvalidOperationException("Content-Length value must be an integral number.", ex); + } + } + private static void ThrowInvalidHeaderCharacter(char ch) { throw new InvalidOperationException(string.Format("Invalid non-ASCII or control character in header: 0x{0:X4}", (ushort)ch)); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs index eeba9695df..5fc2a56d2f 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs @@ -92,6 +92,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http try { await _application.ProcessRequestAsync(context).ConfigureAwait(false); + + var responseHeaders = FrameResponseHeaders; + if (!responseHeaders.HasTransferEncoding && + responseHeaders.HasContentLength && + _responseBytesWritten < responseHeaders.HeaderContentLengthValue.Value) + { + _keepAlive = false; + ReportApplicationError(new InvalidOperationException( + $"Response Content-Length mismatch: too few bytes written ({_responseBytesWritten} of {responseHeaders.HeaderContentLengthValue.Value}).")); + } } catch (Exception ex) { diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameResponseHeaders.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameResponseHeaders.cs index 9ea8056f27..ccb5e0d463 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameResponseHeaders.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameResponseHeaders.cs @@ -13,6 +13,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http private static readonly byte[] _CrLf = new[] { (byte)'\r', (byte)'\n' }; private static readonly byte[] _colonSpace = new[] { (byte)':', (byte)' ' }; + private long? _contentLength; + public bool HasConnection => HeaderConnection.Count != 0; public bool HasTransferEncoding => HeaderTransferEncoding.Count != 0; @@ -23,6 +25,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public bool HasDate => HeaderDate.Count != 0; + public long? HeaderContentLengthValue => _contentLength; + public Enumerator GetEnumerator() { return new Enumerator(this); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/IKestrelTrace.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/IKestrelTrace.cs index 900adc8766..2be2acbfc3 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/IKestrelTrace.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/IKestrelTrace.cs @@ -33,7 +33,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure void ConnectionDisconnectedWrite(string connectionId, int count, Exception ex); - void ConnectionHeadResponseBodyWrite(string connectionId, int count); + void ConnectionHeadResponseBodyWrite(string connectionId, long count); void ConnectionBadRequest(string connectionId, BadHttpRequestException ex); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/KestrelTrace.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/KestrelTrace.cs index 5d3ecff7bd..fcdc3f9208 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/KestrelTrace.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/KestrelTrace.cs @@ -24,7 +24,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal private static readonly Action _applicationError; private static readonly Action _connectionError; private static readonly Action _connectionDisconnectedWrite; - private static readonly Action _connectionHeadResponseBodyWrite; + private static readonly Action _connectionHeadResponseBodyWrite; private static readonly Action _notAllConnectionsClosedGracefully; private static readonly Action _connectionBadRequest; @@ -49,7 +49,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal _connectionDisconnectedWrite = LoggerMessage.Define(LogLevel.Debug, 15, @"Connection id ""{ConnectionId}"" write of ""{count}"" bytes to disconnected client."); _notAllConnectionsClosedGracefully = LoggerMessage.Define(LogLevel.Debug, 16, "Some connections failed to close gracefully during server shutdown."); _connectionBadRequest = LoggerMessage.Define(LogLevel.Information, 17, @"Connection id ""{ConnectionId}"" bad request data: ""{message}"""); - _connectionHeadResponseBodyWrite = LoggerMessage.Define(LogLevel.Debug, 18, @"Connection id ""{ConnectionId}"" write of ""{count}"" body bytes to non-body HEAD response."); + _connectionHeadResponseBodyWrite = LoggerMessage.Define(LogLevel.Debug, 18, @"Connection id ""{ConnectionId}"" write of ""{count}"" body bytes to non-body HEAD response."); } public KestrelTrace(ILogger logger) @@ -135,7 +135,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal _connectionDisconnectedWrite(_logger, connectionId, count, ex); } - public virtual void ConnectionHeadResponseBodyWrite(string connectionId, int count) + public virtual void ConnectionHeadResponseBodyWrite(string connectionId, long count) { _connectionHeadResponseBodyWrite(_logger, connectionId, count, null); } diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs index bcd2e7ecfd..57e753d071 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO; using System.Linq; using System.Net; using System.Net.Http; @@ -14,6 +15,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Internal; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; using Moq; using Xunit; @@ -85,7 +87,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests app.Run(async context => { context.Response.Headers.Add(headerName, headerValue); - + await context.Response.WriteAsync(""); }); }); @@ -299,7 +301,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } [Fact] - public async Task ResponseBodyNotWrittenOnHeadResponse() + public async Task ResponseBodyNotWrittenOnHeadResponseAndLoggedOnlyOnce() { var mockKestrelTrace = new Mock(); @@ -324,7 +326,285 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } mockKestrelTrace.Verify(kestrelTrace => - kestrelTrace.ConnectionHeadResponseBodyWrite(It.IsAny(), "hello, world".Length)); + kestrelTrace.ConnectionHeadResponseBodyWrite(It.IsAny(), "hello, world".Length), Times.Once); + } + + [Fact] + public async Task WhenAppWritesMoreThanContentLengthWriteThrowsAndConnectionCloses() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(httpContext => + { + httpContext.Response.ContentLength = 11; + httpContext.Response.Body.Write(Encoding.ASCII.GetBytes("hello,"), 0, 6); + httpContext.Response.Body.Write(Encoding.ASCII.GetBytes(" world"), 0, 6); + return TaskCache.CompletedTask; + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.ReceiveEnd( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 11", + "", + "hello,"); + } + } + + var logMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too many bytes written (12 of 11).", + logMessage.Exception.Message); + } + + [Fact] + public async Task WhenAppWritesMoreThanContentLengthWriteAsyncThrowsAndConnectionCloses() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 11; + await httpContext.Response.WriteAsync("hello,"); + await httpContext.Response.WriteAsync(" world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.ReceiveEnd( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 11", + "", + "hello,"); + } + } + + var logMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too many bytes written (12 of 11).", + logMessage.Exception.Message); + } + + [Fact] + public async Task WhenAppWritesMoreThanContentLengthAndResponseNotStarted500ResponseSentAndConnectionCloses() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 5; + await httpContext.Response.WriteAsync("hello, world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.ReceiveEnd( + $"HTTP/1.1 500 Internal Server Error", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + var logMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too many bytes written (12 of 5).", + logMessage.Exception.Message); + } + + [Fact] + public async Task WhenAppWritesLessThanContentLengthErrorLogged() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 13; + await httpContext.Response.WriteAsync("hello, world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.ReceiveEnd( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 13", + "", + "hello, world"); + } + } + + var errorMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too few bytes written (12 of 13).", + errorMessage.Exception.Message); + } + + [Fact] + public async Task WhenAppSetsContentLengthButDoesNotWriteBody500ResponseSentAndConnectionCloses() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(httpContext => + { + httpContext.Response.ContentLength = 5; + return TaskCache.CompletedTask; + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.ReceiveEnd( + $"HTTP/1.1 500 Internal Server Error", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + var errorMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too few bytes written (0 of 5).", + errorMessage.Exception.Message); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task WhenAppSetsContentLengthToZeroAndDoesNotWriteNoErrorIsThrown(bool flushResponse) + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 0; + + if (flushResponse) + { + await httpContext.Response.Body.FlushAsync(); + } + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.Equal(0, testLogger.ApplicationErrorsLogged); + } + + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. + [Fact] + public async Task WhenAppSetsTransferEncodingAndContentLengthWritingLessIsNotAnError() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.Headers["Transfer-Encoding"] = "chunked"; + httpContext.Response.ContentLength = 13; + await httpContext.Response.WriteAsync("hello, world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "Content-Length: 13", + "", + "hello, world"); + } + } + + Assert.Equal(0, testLogger.ApplicationErrorsLogged); + } + + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. + [Fact] + public async Task WhenAppSetsTransferEncodingAndContentLengthWritingMoreIsNotAnError() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.Headers["Transfer-Encoding"] = "chunked"; + httpContext.Response.ContentLength = 11; + await httpContext.Response.WriteAsync("hello, world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "Content-Length: 11", + "", + "hello, world"); + } + } + + Assert.Equal(0, testLogger.ApplicationErrorsLogged); } public static TheoryData NullHeaderData diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs index a23a063d12..f9a8b8677d 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs @@ -620,35 +620,6 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } - [Theory] - [MemberData(nameof(ConnectionFilterData))] - public async Task WriteOnHeadResponseLoggedOnlyOnce(TestServiceContext testContext) - { - using (var server = new TestServer(async httpContext => - { - await httpContext.Response.WriteAsync("hello, "); - await httpContext.Response.WriteAsync("world"); - await httpContext.Response.WriteAsync("!"); - }, testContext)) - { - using (var connection = server.CreateConnection()) - { - await connection.SendEnd( - "HEAD / HTTP/1.1", - "", - ""); - await connection.ReceiveEnd( - "HTTP/1.1 200 OK", - $"Date: {testContext.DateHeaderValue}", - "", - ""); - } - - Assert.Equal(1, ((TestKestrelTrace)testContext.Log).HeadResponseWrites); - Assert.Equal(13, ((TestKestrelTrace)testContext.Log).HeadResponseWriteByteCount); - } - } - [Theory] [MemberData(nameof(ConnectionFilterData))] public async Task ThrowingResultsIn500Response(TestServiceContext testContext) @@ -697,11 +668,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "Content-Length: 0", "", ""); - - Assert.False(onStartingCalled); - Assert.Equal(2, testLogger.ApplicationErrorsLogged); } } + + Assert.False(onStartingCalled); + Assert.Equal(2, testLogger.ApplicationErrorsLogged); } [Theory] @@ -739,11 +710,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "Content-Length: 11", "", "Hello World"); - - Assert.True(onStartingCalled); - Assert.Equal(1, testLogger.ApplicationErrorsLogged); } } + + Assert.True(onStartingCalled); + Assert.Equal(1, testLogger.ApplicationErrorsLogged); } [Theory] @@ -781,11 +752,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "Content-Length: 11", "", "Hello"); - - Assert.True(onStartingCalled); - Assert.Equal(1, testLogger.ApplicationErrorsLogged); } } + + Assert.True(onStartingCalled); + Assert.Equal(1, testLogger.ApplicationErrorsLogged); } [Theory] @@ -925,16 +896,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "Content-Length: 0", "", ""); - - Assert.Equal(2, onStartingCallCount2); - - // The first registered OnStarting callback should not be called, - // since they are called LIFO and the other one failed. - Assert.Equal(0, onStartingCallCount1); - - Assert.Equal(2, testLogger.ApplicationErrorsLogged); } } + + // The first registered OnStarting callback should not be called, + // since they are called LIFO and the other one failed. + Assert.Equal(0, onStartingCallCount1); + Assert.Equal(2, onStartingCallCount2); + Assert.Equal(2, testLogger.ApplicationErrorsLogged); } [Theory] @@ -979,12 +948,12 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "", "Hello World"); } - - // All OnCompleted callbacks should be called even if they throw. - Assert.Equal(2, testLogger.ApplicationErrorsLogged); - Assert.True(onCompletedCalled1); - Assert.True(onCompletedCalled2); } + + // All OnCompleted callbacks should be called even if they throw. + Assert.Equal(2, testLogger.ApplicationErrorsLogged); + Assert.True(onCompletedCalled1); + Assert.True(onCompletedCalled2); } [Theory] diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameResponseHeadersTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameResponseHeadersTests.cs index c57ddff257..ff27b37cf9 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameResponseHeadersTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameResponseHeadersTests.cs @@ -78,24 +78,29 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { var responseHeaders = new FrameResponseHeaders(); - Assert.Throws(() => { + Assert.Throws(() => + { ((IHeaderDictionary)responseHeaders)[key] = value; }); - Assert.Throws(() => { + Assert.Throws(() => + { ((IHeaderDictionary)responseHeaders)[key] = new StringValues(new[] { "valid", value }); }); - Assert.Throws(() => { + Assert.Throws(() => + { ((IDictionary)responseHeaders)[key] = value; }); - Assert.Throws(() => { + Assert.Throws(() => + { var kvp = new KeyValuePair(key, value); ((ICollection>)responseHeaders).Add(kvp); }); - Assert.Throws(() => { + Assert.Throws(() => + { var kvp = new KeyValuePair(key, value); ((IDictionary)responseHeaders).Add(key, value); }); @@ -142,5 +147,83 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.Throws(() => dictionary.Clear()); } + + [Fact] + public void ThrowsWhenAddingContentLengthWithNonNumericValue() + { + var headers = new FrameResponseHeaders(); + var dictionary = (IDictionary)headers; + + Assert.Throws(() => dictionary.Add("Content-Length", new[] { "bad" })); + } + + [Fact] + public void ThrowsWhenSettingContentLengthToNonNumericValue() + { + var headers = new FrameResponseHeaders(); + var dictionary = (IDictionary)headers; + + Assert.Throws(() => ((IHeaderDictionary)headers)["Content-Length"] = "bad"); + } + + [Fact] + public void ThrowsWhenAssigningHeaderContentLengthToNonNumericValue() + { + var headers = new FrameResponseHeaders(); + Assert.Throws(() => headers.HeaderContentLength = "bad"); + } + + [Fact] + public void ContentLengthValueCanBeReadAsLongAfterAddingHeader() + { + var headers = new FrameResponseHeaders(); + var dictionary = (IDictionary)headers; + dictionary.Add("Content-Length", "42"); + + Assert.Equal(42, headers.HeaderContentLengthValue); + } + + [Fact] + public void ContentLengthValueCanBeReadAsLongAfterSettingHeader() + { + var headers = new FrameResponseHeaders(); + var dictionary = (IDictionary)headers; + dictionary["Content-Length"] = "42"; + + Assert.Equal(42, headers.HeaderContentLengthValue); + } + + [Fact] + public void ContentLengthValueCanBeReadAsLongAfterAssigningHeader() + { + var headers = new FrameResponseHeaders(); + headers.HeaderContentLength = "42"; + + Assert.Equal(42, headers.HeaderContentLengthValue); + } + + [Fact] + public void ContentLengthValueClearedWhenHeaderIsRemoved() + { + var headers = new FrameResponseHeaders(); + headers.HeaderContentLength = "42"; + var dictionary = (IDictionary)headers; + + dictionary.Remove("Content-Length"); + + Assert.Equal(null, headers.HeaderContentLengthValue); + } + + [Fact] + public void ContentLengthValueClearedWhenHeadersCleared() + { + var headers = new FrameResponseHeaders(); + headers.HeaderContentLength = "42"; + var dictionary = (IDictionary)headers; + + dictionary.Clear(); + + Assert.Equal(null, headers.HeaderContentLengthValue); + } } } diff --git a/test/shared/TestApplicationErrorLogger.cs b/test/shared/TestApplicationErrorLogger.cs index d2d3731a9d..5036f1cec9 100644 --- a/test/shared/TestApplicationErrorLogger.cs +++ b/test/shared/TestApplicationErrorLogger.cs @@ -2,6 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; +using System.Linq; using Microsoft.AspNetCore.Server.Kestrel.Internal; using Microsoft.Extensions.Logging; @@ -12,11 +14,13 @@ namespace Microsoft.AspNetCore.Testing // Application errors are logged using 13 as the eventId. private const int ApplicationErrorEventId = 13; - public int TotalErrorsLogged { get; set; } + public List Messages { get; } = new List(); - public int CriticalErrorsLogged { get; set; } + public int TotalErrorsLogged => Messages.Count(message => message.LogLevel == LogLevel.Error); - public int ApplicationErrorsLogged { get; set; } + public int CriticalErrorsLogged => Messages.Count(message => message.LogLevel == LogLevel.Critical); + + public int ApplicationErrorsLogged => Messages.Count(message => message.EventId.Id == ApplicationErrorEventId); public IDisposable BeginScope(TState state) { @@ -34,20 +38,14 @@ namespace Microsoft.AspNetCore.Testing Console.WriteLine($"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception?.Message}"); #endif - if (eventId.Id == ApplicationErrorEventId) - { - ApplicationErrorsLogged++; - } + Messages.Add(new LogMessage { LogLevel = logLevel, EventId = eventId, Exception = exception }); + } - if (logLevel == LogLevel.Error) - { - TotalErrorsLogged++; - } - - if (logLevel == LogLevel.Critical) - { - CriticalErrorsLogged++; - } + public class LogMessage + { + public LogLevel LogLevel { get; set; } + public EventId EventId { get; set; } + public Exception Exception { get; set; } } } } diff --git a/test/shared/TestKestrelTrace.cs b/test/shared/TestKestrelTrace.cs index 63dbfc0f73..814005d4d1 100644 --- a/test/shared/TestKestrelTrace.cs +++ b/test/shared/TestKestrelTrace.cs @@ -13,10 +13,6 @@ namespace Microsoft.AspNetCore.Testing { } - public int HeadResponseWrites { get; set; } - - public int HeadResponseWriteByteCount { get; set; } - public override void ConnectionRead(string connectionId, int count) { //_logger.LogDebug(1, @"Connection id ""{ConnectionId}"" recv {count} bytes.", connectionId, count); @@ -31,11 +27,5 @@ namespace Microsoft.AspNetCore.Testing { //_logger.LogDebug(1, @"Connection id ""{ConnectionId}"" send finished with status {status}.", connectionId, status); } - - public override void ConnectionHeadResponseBodyWrite(string connectionId, int count) - { - HeadResponseWrites++; - HeadResponseWriteByteCount = count; - } } } \ No newline at end of file diff --git a/tools/Microsoft.AspNetCore.Server.Kestrel.GeneratedCode/KnownHeaders.cs b/tools/Microsoft.AspNetCore.Server.Kestrel.GeneratedCode/KnownHeaders.cs index bb4f4485a0..7d8b57256f 100644 --- a/tools/Microsoft.AspNetCore.Server.Kestrel.GeneratedCode/KnownHeaders.cs +++ b/tools/Microsoft.AspNetCore.Server.Kestrel.GeneratedCode/KnownHeaders.cs @@ -14,6 +14,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.GeneratedCode return values.Any() ? values.Select(formatter).Aggregate((a, b) => a + b) : ""; } + static string If(bool condition, Func formatter) + { + return condition ? formatter() : ""; + } + class KnownHeader { public string Name { get; set; } @@ -228,7 +233,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http return StringValues.Empty; }} set - {{ + {{{If(loop.ClassName == "FrameResponseHeaders" && header.Identifier == "ContentLength", () => @" + _contentLength = ParseContentLength(value);")} {header.SetBit()}; _headers._{header.Identifier} = value; {(header.EnhancedSetter == false ? "" : $@" _headers._raw{header.Identifier} = null;")} @@ -304,7 +310,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http case {byLength.Key}: {{{Each(byLength, header => $@" if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase)) - {{ + {{{If(loop.ClassName == "FrameResponseHeaders" && header.Identifier == "ContentLength", () => @" + _contentLength = ParseContentLength(value);")} {header.SetBit()}; _headers._{header.Identifier} = value;{(header.EnhancedSetter == false ? "" : $@" _headers._raw{header.Identifier} = null;")} @@ -328,7 +335,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http if ({header.TestBit()}) {{ ThrowDuplicateKeyException(); - }} + }}{ + If(loop.ClassName == "FrameResponseHeaders" && header.Identifier == "ContentLength", () => @" + _contentLength = ParseContentLength(value);")} {header.SetBit()}; _headers._{header.Identifier} = value;{(header.EnhancedSetter == false ? "" : $@" _headers._raw{header.Identifier} = null;")} @@ -349,7 +358,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase)) {{ if ({header.TestBit()}) - {{ + {{{If(loop.ClassName == "FrameResponseHeaders" && header.Identifier == "ContentLength", () => @" + _contentLength = null;")} {header.ClearBit()}; _headers._{header.Identifier} = StringValues.Empty;{(header.EnhancedSetter == false ? "" : $@" _headers._raw{header.Identifier} = null;")} @@ -369,6 +379,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http {{ _bits = 0; _headers = default(HeaderReferences); + {(loop.ClassName == "FrameResponseHeaders" ? "_contentLength = null;" : "")} MaybeUnknown?.Clear(); }} @@ -435,7 +446,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http _headers._{header.Identifier} = AppendValue(_headers._{header.Identifier}, value); }} else - {{ + {{{If(loop.ClassName == "FrameResponseHeaders" && header.Identifier == "ContentLength", () => @" + _contentLength = ParseContentLength(value);")} {header.SetBit()}; _headers._{header.Identifier} = new StringValues(value);{(header.EnhancedSetter == false ? "" : $@" _headers._raw{header.Identifier} = null;")}