diff --git a/benchmarkapps/PlatformBenchmarks/BenchmarkApplication.cs b/benchmarkapps/PlatformBenchmarks/BenchmarkApplication.cs index 6551bee358..d65452c13c 100644 --- a/benchmarkapps/PlatformBenchmarks/BenchmarkApplication.cs +++ b/benchmarkapps/PlatformBenchmarks/BenchmarkApplication.cs @@ -78,7 +78,7 @@ namespace PlatformBenchmarks } private static void PlainText(PipeWriter pipeWriter) { - var writer = new BufferWriter(pipeWriter); + var writer = new CountingBufferWriter(pipeWriter); // HTTP 1.1 OK writer.Write(_http11OK); @@ -105,7 +105,7 @@ namespace PlatformBenchmarks private static void Json(PipeWriter pipeWriter) { - var writer = new BufferWriter(pipeWriter); + var writer = new CountingBufferWriter(pipeWriter); // HTTP 1.1 OK writer.Write(_http11OK); @@ -134,7 +134,7 @@ namespace PlatformBenchmarks private static void Default(PipeWriter pipeWriter) { - var writer = new BufferWriter(pipeWriter); + var writer = new CountingBufferWriter(pipeWriter); // HTTP 1.1 OK writer.Write(_http11OK); diff --git a/benchmarkapps/PlatformBenchmarks/BufferWriterExtensions.cs b/benchmarkapps/PlatformBenchmarks/BufferWriterExtensions.cs deleted file mode 100644 index 0da6b24fd4..0000000000 --- a/benchmarkapps/PlatformBenchmarks/BufferWriterExtensions.cs +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.Buffers.Text; -using System.Diagnostics; -using System.IO.Pipelines; -using System.Runtime.CompilerServices; - -namespace System.Buffers -{ - internal static class BufferWriterExtensions - { - private const int MaxULongByteLength = 20; - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static void WriteNumeric(ref this BufferWriter buffer, ulong number) - { - // Try to format directly - if (Utf8Formatter.TryFormat(number, buffer.Span, out int bytesWritten)) - { - buffer.Advance(bytesWritten); - } - else - { - // Ask for at least 20 bytes - buffer.Ensure(MaxULongByteLength); - - Debug.Assert(buffer.Span.Length >= 20, "Buffer is < 20 bytes"); - - // Try again - if (Utf8Formatter.TryFormat(number, buffer.Span, out bytesWritten)) - { - buffer.Advance(bytesWritten); - } - } - } - } -} \ No newline at end of file diff --git a/benchmarkapps/PlatformBenchmarks/PlatformBenchmarks.csproj b/benchmarkapps/PlatformBenchmarks/PlatformBenchmarks.csproj index a883710816..4d4126c688 100644 --- a/benchmarkapps/PlatformBenchmarks/PlatformBenchmarks.csproj +++ b/benchmarkapps/PlatformBenchmarks/PlatformBenchmarks.csproj @@ -16,13 +16,11 @@ - - diff --git a/src/Connections.Abstractions/DefaultConnectionContext.cs b/src/Connections.Abstractions/DefaultConnectionContext.cs index 4b3377350d..1b161c7253 100644 --- a/src/Connections.Abstractions/DefaultConnectionContext.cs +++ b/src/Connections.Abstractions/DefaultConnectionContext.cs @@ -6,22 +6,25 @@ using System.Collections.Generic; using System.IO.Pipelines; using System.Security.Claims; using System.Threading; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Http.Features; -using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Connections { public class DefaultConnectionContext : ConnectionContext, + IDisposable, IConnectionIdFeature, IConnectionItemsFeature, IConnectionTransportFeature, - IConnectionUserFeature + IConnectionUserFeature, + IConnectionLifetimeFeature { + private CancellationTokenSource _connectionClosedTokenSource = new CancellationTokenSource(); + public DefaultConnectionContext() : this(Guid.NewGuid().ToString()) { + ConnectionClosed = _connectionClosedTokenSource.Token; } /// @@ -38,6 +41,7 @@ namespace Microsoft.AspNetCore.Connections Features.Set(this); Features.Set(this); Features.Set(this); + Features.Set(this); } public DefaultConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application) @@ -58,5 +62,17 @@ namespace Microsoft.AspNetCore.Connections public IDuplexPipe Application { get; set; } public override IDuplexPipe Transport { get; set; } + + public CancellationToken ConnectionClosed { get; set; } + + public virtual void Abort() + { + ThreadPool.QueueUserWorkItem(cts => ((CancellationTokenSource)cts).Cancel(), _connectionClosedTokenSource); + } + + public void Dispose() + { + _connectionClosedTokenSource.Dispose(); + } } } diff --git a/src/Connections.Abstractions/Features/IConnectionLifetimeFeature.cs b/src/Connections.Abstractions/Features/IConnectionLifetimeFeature.cs new file mode 100644 index 0000000000..8f804de898 --- /dev/null +++ b/src/Connections.Abstractions/Features/IConnectionLifetimeFeature.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading; + +namespace Microsoft.AspNetCore.Connections.Features +{ + public interface IConnectionLifetimeFeature + { + CancellationToken ConnectionClosed { get; set; } + void Abort(); + } +} diff --git a/src/Kestrel.Core/Internal/Http/ChunkWriter.cs b/src/Kestrel.Core/Internal/Http/ChunkWriter.cs index 3d8cc4566b..2184937b07 100644 --- a/src/Kestrel.Core/Internal/Http/ChunkWriter.cs +++ b/src/Kestrel.Core/Internal/Http/ChunkWriter.cs @@ -48,14 +48,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return new ArraySegment(bytes, offset, 10 - offset); } - internal static int WriteBeginChunkBytes(ref BufferWriter start, int dataCount) + internal static int WriteBeginChunkBytes(ref CountingBufferWriter start, int dataCount) { var chunkSegment = BeginChunkBytes(dataCount); start.Write(new ReadOnlySpan(chunkSegment.Array, chunkSegment.Offset, chunkSegment.Count)); return chunkSegment.Count; } - internal static void WriteEndChunkBytes(ref BufferWriter start) + internal static void WriteEndChunkBytes(ref CountingBufferWriter start) { start.Write(new ReadOnlySpan(_endChunkBytes.Array, _endChunkBytes.Offset, _endChunkBytes.Count)); } diff --git a/src/Kestrel.Core/Internal/Http/CountingBufferWriter.cs b/src/Kestrel.Core/Internal/Http/CountingBufferWriter.cs new file mode 100644 index 0000000000..e3299c6027 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http/CountingBufferWriter.cs @@ -0,0 +1,99 @@ + +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; + +namespace System.Buffers +{ + // TODO: Once this is public, update the actual CountingBufferWriter in the Common repo, + // and go back to using that. + internal ref struct CountingBufferWriter where T: IBufferWriter + { + private T _output; + private Span _span; + private int _buffered; + private long _bytesCommitted; + + public CountingBufferWriter(T output) + { + _buffered = 0; + _bytesCommitted = 0; + _output = output; + _span = output.GetSpan(); + } + + public Span Span => _span; + public long BytesCommitted => _bytesCommitted; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Commit() + { + var buffered = _buffered; + if (buffered > 0) + { + _bytesCommitted += buffered; + _buffered = 0; + _output.Advance(buffered); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Advance(int count) + { + _buffered += count; + _span = _span.Slice(count); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Write(ReadOnlySpan source) + { + if (_span.Length >= source.Length) + { + source.CopyTo(_span); + Advance(source.Length); + } + else + { + WriteMultiBuffer(source); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Ensure(int count = 1) + { + if (_span.Length < count) + { + EnsureMore(count); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void EnsureMore(int count = 0) + { + if (_buffered > 0) + { + Commit(); + } + + _output.GetMemory(count); + _span = _output.GetSpan(); + } + + private void WriteMultiBuffer(ReadOnlySpan source) + { + while (source.Length > 0) + { + if (_span.Length == 0) + { + EnsureMore(); + } + + var writable = Math.Min(source.Length, _span.Length); + source.Slice(0, writable).CopyTo(_span); + source = source.Slice(writable); + Advance(writable); + } + } + } +} diff --git a/src/Kestrel.Core/Internal/Http/Http1Connection.cs b/src/Kestrel.Core/Internal/Http/Http1Connection.cs index 6fd58d08b0..f93a7997fa 100644 --- a/src/Kestrel.Core/Internal/Http/Http1Connection.cs +++ b/src/Kestrel.Core/Internal/Http/Http1Connection.cs @@ -12,6 +12,8 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Connections.Abstractions; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { @@ -42,7 +44,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _keepAliveTicks = ServerOptions.Limits.KeepAliveTimeout.Ticks; _requestHeadersTimeoutTicks = ServerOptions.Limits.RequestHeadersTimeout.Ticks; - Output = new Http1OutputProducer(_context.Application.Input, _context.Transport.Output, _context.ConnectionId, _context.ServiceContext.Log, _context.TimeoutControl); + Output = new Http1OutputProducer( + _context.Application.Input, + _context.Transport.Output, + _context.ConnectionId, + _context.ServiceContext.Log, + _context.TimeoutControl, + _context.ConnectionFeatures.Get(), + _context.ConnectionFeatures.Get()); } public PipeReader Input => _context.Transport.Input; diff --git a/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs b/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs index c47bf114b6..d65805cc83 100644 --- a/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs +++ b/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs @@ -89,6 +89,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } else if (result.IsCompleted) { + // Treat any FIN from an upgraded request as expected. + // It's up to higher-level consumer (i.e. WebSocket middleware) to determine + // if the end is actually expected based on higher-level framing. + if (RequestUpgrade) + { + break; + } + BadHttpRequestException.Throw(RequestRejectionReason.UnexpectedEndOfRequestContent); } diff --git a/src/Kestrel.Core/Internal/Http/Http1OutputProducer.cs b/src/Kestrel.Core/Internal/Http/Http1OutputProducer.cs index d767539185..4ef0dff24e 100644 --- a/src/Kestrel.Core/Internal/Http/Http1OutputProducer.cs +++ b/src/Kestrel.Core/Internal/Http/Http1OutputProducer.cs @@ -8,7 +8,9 @@ using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { @@ -22,11 +24,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private readonly string _connectionId; private readonly ITimeoutControl _timeoutControl; private readonly IKestrelTrace _log; + private readonly IConnectionLifetimeFeature _lifetimeFeature; + private readonly IBytesWrittenFeature _transportBytesWrittenFeature; // This locks access to to all of the below fields private readonly object _contextLock = new object(); private bool _completed = false; + private bool _aborted; + private long _unflushedBytes; + private long _totalBytesCommitted; private readonly PipeWriter _pipeWriter; private readonly PipeReader _outputPipeReader; @@ -45,7 +52,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http PipeWriter pipeWriter, string connectionId, IKestrelTrace log, - ITimeoutControl timeoutControl) + ITimeoutControl timeoutControl, + IConnectionLifetimeFeature lifetimeFeature, + IBytesWrittenFeature transportBytesWrittenFeature) { _outputPipeReader = outputPipeReader; _pipeWriter = pipeWriter; @@ -53,6 +62,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _timeoutControl = timeoutControl; _log = log; _flushCompleted = OnFlushCompleted; + _lifetimeFeature = lifetimeFeature; + _transportBytesWrittenFeature = transportBytesWrittenFeature; } public Task WriteDataAsync(ReadOnlySpan buffer, CancellationToken cancellationToken = default(CancellationToken)) @@ -75,7 +86,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return WriteAsync(Constants.EmptyData, cancellationToken); } - public void Write(Action callback, T state) + public void Write(Func callback, T state) { lock (_contextLock) { @@ -85,11 +96,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } var buffer = _pipeWriter; - callback(buffer, state); + var bytesCommitted = callback(buffer, state); + _unflushedBytes += bytesCommitted; + _totalBytesCommitted += bytesCommitted; } } - public Task WriteAsync(Action callback, T state) + public Task WriteAsync(Func callback, T state) { lock (_contextLock) { @@ -99,7 +112,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } var buffer = _pipeWriter; - callback(buffer, state); + var bytesCommitted = callback(buffer, state); + _unflushedBytes += bytesCommitted; + _totalBytesCommitted += bytesCommitted; } return FlushAsync(); @@ -115,14 +130,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } var buffer = _pipeWriter; - var writer = new BufferWriter(buffer); + var writer = new CountingBufferWriter(buffer); writer.Write(_bytesHttpVersion11); var statusBytes = ReasonPhrases.ToStatusBytes(statusCode, reasonPhrase); writer.Write(statusBytes); responseHeaders.CopyTo(ref writer); writer.Write(_bytesEndHeaders); + writer.Commit(); + + _unflushedBytes += writer.BytesCommitted; + _totalBytesCommitted += writer.BytesCommitted; } } @@ -138,23 +157,41 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _log.ConnectionDisconnect(_connectionId); _completed = true; _pipeWriter.Complete(); + + var unsentBytes = _totalBytesCommitted - _transportBytesWrittenFeature.TotalBytesWritten; + + if (unsentBytes > 0) + { + // unsentBytes should never be over 64KB in the default configuration. + _timeoutControl.StartTimingWrite((int)Math.Min(unsentBytes, int.MaxValue)); + _pipeWriter.OnReaderCompleted((ex, state) => ((ITimeoutControl)state).StopTimingWrite(), _timeoutControl); + } } } public void Abort(Exception error) { + // Abort can be called after Dispose if there's a flush timeout. + // It's important to still call _lifetimeFeature.Abort() in this case. + lock (_contextLock) { - if (_completed) + if (_aborted) { return; } - _log.ConnectionDisconnect(_connectionId); - _completed = true; + if (!_completed) + { + _log.ConnectionDisconnect(_connectionId); + _completed = true; - _outputPipeReader.CancelPendingRead(); - _pipeWriter.Complete(error); + _outputPipeReader.CancelPendingRead(); + _pipeWriter.Complete(error); + } + + _aborted = true; + _lifetimeFeature.Abort(); } } @@ -177,13 +214,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } writableBuffer = _pipeWriter; - var writer = new BufferWriter(writableBuffer); + var writer = new CountingBufferWriter(writableBuffer); if (buffer.Length > 0) { writer.Write(buffer); - bytesWritten += buffer.Length; + + _unflushedBytes += buffer.Length; + _totalBytesCommitted += buffer.Length; } writer.Commit(); + + bytesWritten = _unflushedBytes; + _unflushedBytes = 0; } return FlushAsync(writableBuffer, bytesWritten, cancellationToken); diff --git a/src/Kestrel.Core/Internal/Http/HttpHeaders.Generated.cs b/src/Kestrel.Core/Internal/Http/HttpHeaders.Generated.cs index d84f15706d..f81e87a06c 100644 --- a/src/Kestrel.Core/Internal/Http/HttpHeaders.Generated.cs +++ b/src/Kestrel.Core/Internal/Http/HttpHeaders.Generated.cs @@ -7765,7 +7765,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return true; } - internal void CopyToFast(ref BufferWriter output) + internal void CopyToFast(ref CountingBufferWriter output) { var tempBits = _bits | (_contentLength.HasValue ? -9223372036854775808L : 0); diff --git a/src/Kestrel.Core/Internal/Http/HttpProtocol.FeatureCollection.cs b/src/Kestrel.Core/Internal/Http/HttpProtocol.FeatureCollection.cs index 055e8a1e76..3c53523f8a 100644 --- a/src/Kestrel.Core/Internal/Http/HttpProtocol.FeatureCollection.cs +++ b/src/Kestrel.Core/Internal/Http/HttpProtocol.FeatureCollection.cs @@ -8,6 +8,7 @@ using System.IO; using System.Net; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; @@ -282,7 +283,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http void IHttpRequestLifetimeFeature.Abort() { - Abort(error: null); + Abort(new ConnectionAbortedException()); } } } diff --git a/src/Kestrel.Core/Internal/Http/HttpProtocol.cs b/src/Kestrel.Core/Internal/Http/HttpProtocol.cs index 7db0062a15..4e35490a79 100644 --- a/src/Kestrel.Core/Internal/Http/HttpProtocol.cs +++ b/src/Kestrel.Core/Internal/Http/HttpProtocol.cs @@ -32,7 +32,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private static readonly byte[] _bytesConnectionKeepAlive = Encoding.ASCII.GetBytes("\r\nConnection: keep-alive"); private static readonly byte[] _bytesTransferEncodingChunked = Encoding.ASCII.GetBytes("\r\nTransfer-Encoding: chunked"); private static readonly byte[] _bytesServer = Encoding.ASCII.GetBytes("\r\nServer: " + Constants.ServerName); - private static readonly Action> _writeChunk = WriteChunk; + private static readonly Func, long> _writeChunk = WriteChunk; private readonly object _onStartingSync = new Object(); private readonly object _onCompletedSync = new Object(); @@ -411,21 +411,29 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } /// - /// Immediate kill the connection and poison the request and response streams. + /// Immediately kill the connection and poison the request and response streams with an error if there is one. /// public void Abort(Exception error) { - if (Interlocked.Exchange(ref _requestAborted, 1) == 0) + if (Interlocked.Exchange(ref _requestAborted, 1) != 0) { - _keepAlive = false; - - _streams?.Abort(error); - - Output.Abort(error); - - // Potentially calling user code. CancelRequestAbortedToken logs any exceptions. - ServiceContext.Scheduler.Schedule(state => ((HttpProtocol)state).CancelRequestAbortedToken(), this); + return; } + + _keepAlive = false; + + // If Abort() isn't called with an exception, there was a FIN. In this case, even though the connection is + // still closed immediately, we allow the app to drain the data in the request buffer. If the request data + // was truncated, MessageBody will complete the RequestBodyPipe with an error. + if (error != null) + { + _streams?.Abort(error); + } + + Output.Abort(error); + + // Potentially calling user code. CancelRequestAbortedToken logs any exceptions. + ServiceContext.Scheduler.Schedule(state => ((HttpProtocol)state).CancelRequestAbortedToken(), this); } public void OnHeader(Span name, Span value) @@ -474,6 +482,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { OnRequestProcessingEnding(); await TryProduceInvalidRequestResponse(); + + // Prevent RequestAborted from firing. + Reset(); + Output.Dispose(); } catch (Exception ex) @@ -911,16 +923,22 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return Output.WriteAsync(_writeChunk, data); } - private static void WriteChunk(PipeWriter writableBuffer, ReadOnlyMemory buffer) + private static long WriteChunk(PipeWriter writableBuffer, ReadOnlyMemory buffer) { - var writer = new BufferWriter(writableBuffer); + var bytesWritten = 0L; if (buffer.Length > 0) { + var writer = new CountingBufferWriter(writableBuffer); + ChunkWriter.WriteBeginChunkBytes(ref writer, buffer.Length); writer.Write(buffer.Span); ChunkWriter.WriteEndChunkBytes(ref writer); writer.Commit(); + + bytesWritten = writer.BytesCommitted; } + + return bytesWritten; } private static ArraySegment CreateAsciiByteArraySegment(string text) diff --git a/src/Kestrel.Core/Internal/Http/HttpResponseHeaders.cs b/src/Kestrel.Core/Internal/Http/HttpResponseHeaders.cs index 1df80f3dc6..a4b81cf69a 100644 --- a/src/Kestrel.Core/Internal/Http/HttpResponseHeaders.cs +++ b/src/Kestrel.Core/Internal/Http/HttpResponseHeaders.cs @@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return GetEnumerator(); } - internal void CopyTo(ref BufferWriter buffer) + internal void CopyTo(ref CountingBufferWriter buffer) { CopyToFast(ref buffer); if (MaybeUnknown != null) diff --git a/src/Kestrel.Core/Internal/Http/IHttpOutputProducer.cs b/src/Kestrel.Core/Internal/Http/IHttpOutputProducer.cs index abc2b4454a..0c45253bdc 100644 --- a/src/Kestrel.Core/Internal/Http/IHttpOutputProducer.cs +++ b/src/Kestrel.Core/Internal/Http/IHttpOutputProducer.cs @@ -11,7 +11,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http public interface IHttpOutputProducer : IDisposable { void Abort(Exception error); - Task WriteAsync(Action callback, T state); + Task WriteAsync(Func callback, T state); Task FlushAsync(CancellationToken cancellationToken); Task Write100ContinueAsync(CancellationToken cancellationToken); void WriteResponseHeaders(int statusCode, string ReasonPhrase, HttpResponseHeaders responseHeaders); diff --git a/src/Kestrel.Core/Internal/Http/PipelineExtensions.cs b/src/Kestrel.Core/Internal/Http/PipelineExtensions.cs index e56c43b23c..fce822b6c5 100644 --- a/src/Kestrel.Core/Internal/Http/PipelineExtensions.cs +++ b/src/Kestrel.Core/Internal/Http/PipelineExtensions.cs @@ -40,7 +40,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return result; } - internal static unsafe void WriteAsciiNoValidation(ref this BufferWriter buffer, string data) + internal static unsafe void WriteAsciiNoValidation(ref this CountingBufferWriter buffer, string data) { if (string.IsNullOrEmpty(data)) { @@ -69,7 +69,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static unsafe void WriteNumeric(ref this BufferWriter buffer, ulong number) + internal static unsafe void WriteNumeric(ref this CountingBufferWriter buffer, ulong number) { const byte AsciiDigitStart = (byte)'0'; @@ -119,7 +119,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } [MethodImpl(MethodImplOptions.NoInlining)] - private static void WriteNumericMultiWrite(ref this BufferWriter buffer, ulong number) + private static void WriteNumericMultiWrite(ref this CountingBufferWriter buffer, ulong number) { const byte AsciiDigitStart = (byte)'0'; @@ -140,7 +140,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } [MethodImpl(MethodImplOptions.NoInlining)] - private unsafe static void WriteAsciiMultiWrite(ref this BufferWriter buffer, string data) + private unsafe static void WriteAsciiMultiWrite(ref this CountingBufferWriter buffer, string data) { var remaining = data.Length; diff --git a/src/Kestrel.Core/Internal/Http2/Http2OutputProducer.cs b/src/Kestrel.Core/Internal/Http2/Http2OutputProducer.cs index 571f5a32d2..0fe1fbf344 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2OutputProducer.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2OutputProducer.cs @@ -30,7 +30,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 // TODO: RST_STREAM? } - public Task WriteAsync(Action callback, T state) + public Task WriteAsync(Func callback, T state) { throw new NotImplementedException(); } diff --git a/src/Kestrel.Core/Internal/HttpConnection.cs b/src/Kestrel.Core/Internal/HttpConnection.cs index d57a7f920c..5673485d77 100644 --- a/src/Kestrel.Core/Internal/HttpConnection.cs +++ b/src/Kestrel.Core/Internal/HttpConnection.cs @@ -370,6 +370,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal public void Tick(DateTimeOffset now) { + if (_protocolSelectionState == ProtocolSelectionState.Stopped) + { + // It's safe to check for timeouts on a dead connection, + // but try not to in order to avoid extraneous logs. + return; + } + var timestamp = now.Ticks; CheckForTimeout(timestamp); @@ -554,17 +561,27 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal if (minResponseDataRate != null) { - var timeoutTicks = Math.Max( + // Add Heartbeat.Interval since this can be called right before the next heartbeat. + var currentTimeUpperBound = _lastTimestamp + Heartbeat.Interval.Ticks; + var ticksToCompleteWriteAtMinRate = TimeSpan.FromSeconds(size / minResponseDataRate.BytesPerSecond).Ticks; + + // If ticksToCompleteWriteAtMinRate is less than the configured grace period, + // allow that write to take up to the grace period to complete. Only add the grace period + // to the current time and not to any accumulated timeout. + var singleWriteTimeoutTimestamp = currentTimeUpperBound + Math.Max( minResponseDataRate.GracePeriod.Ticks, - TimeSpan.FromSeconds(size / minResponseDataRate.BytesPerSecond).Ticks); + ticksToCompleteWriteAtMinRate); - if (_writeTimingWrites == 0) - { - // Add Heartbeat.Interval since this can be called right before the next heartbeat. - _writeTimingTimeoutTimestamp = _lastTimestamp + Heartbeat.Interval.Ticks; - } + // Don't penalize a connection for completing previous writes more quickly than required. + // We don't want to kill a connection when flushing the chunk terminator just because the previous + // chunk was large if the previous chunk was flushed quickly. - _writeTimingTimeoutTimestamp += timeoutTicks; + // Don't add any grace period to this accumulated timeout because the grace period could + // get accumulated repeatedly making the timeout for a bunch of consecutive small writes + // far too conservative. + var accumulatedWriteTimeoutTimestamp = _writeTimingTimeoutTimestamp + ticksToCompleteWriteAtMinRate; + + _writeTimingTimeoutTimestamp = Math.Max(singleWriteTimeoutTimestamp, accumulatedWriteTimeoutTimestamp); _writeTimingWrites++; } } diff --git a/src/Kestrel.Core/Properties/AssemblyInfo.cs b/src/Kestrel.Core/Properties/AssemblyInfo.cs index 27495d5268..6898d541a6 100644 --- a/src/Kestrel.Core/Properties/AssemblyInfo.cs +++ b/src/Kestrel.Core/Properties/AssemblyInfo.cs @@ -10,3 +10,4 @@ using System.Runtime.CompilerServices; [assembly: InternalsVisibleTo("Kestrel.Performance, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] [assembly: InternalsVisibleTo("Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] [assembly: InternalsVisibleTo("Http2SampleApp, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] +[assembly: InternalsVisibleTo("PlatformBenchmarks, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] diff --git a/src/Kestrel.Transport.Abstractions/Internal/IBytesWrittenFeature.cs b/src/Kestrel.Transport.Abstractions/Internal/IBytesWrittenFeature.cs new file mode 100644 index 0000000000..e4bf998f37 --- /dev/null +++ b/src/Kestrel.Transport.Abstractions/Internal/IBytesWrittenFeature.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + public interface IBytesWrittenFeature + { + long TotalBytesWritten { get; } + } +} diff --git a/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.Features.cs b/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.Features.cs index 8ab96ba108..4b65762f30 100644 --- a/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.Features.cs +++ b/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.Features.cs @@ -4,8 +4,9 @@ using System.Collections; using System.Collections.Generic; using System.IO.Pipelines; using System.Net; -using Microsoft.AspNetCore.Http.Features; +using System.Threading; using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal { @@ -16,7 +17,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal IConnectionItemsFeature, IMemoryPoolFeature, IApplicationTransportFeature, - ITransportSchedulerFeature + ITransportSchedulerFeature, + IConnectionLifetimeFeature, + IBytesWrittenFeature { private static readonly Type IHttpConnectionFeatureType = typeof(IHttpConnectionFeature); private static readonly Type IConnectionIdFeatureType = typeof(IConnectionIdFeature); @@ -25,6 +28,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal private static readonly Type IMemoryPoolFeatureType = typeof(IMemoryPoolFeature); private static readonly Type IApplicationTransportFeatureType = typeof(IApplicationTransportFeature); private static readonly Type ITransportSchedulerFeatureType = typeof(ITransportSchedulerFeature); + private static readonly Type IConnectionLifetimeFeatureType = typeof(IConnectionLifetimeFeature); + private static readonly Type IBytesWrittenFeatureType = typeof(IBytesWrittenFeature); private object _currentIHttpConnectionFeature; private object _currentIConnectionIdFeature; @@ -33,6 +38,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal private object _currentIMemoryPoolFeature; private object _currentIApplicationTransportFeature; private object _currentITransportSchedulerFeature; + private object _currentIConnectionLifetimeFeature; + private object _currentIBytesWrittenFeature; private int _featureRevision; @@ -127,6 +134,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal set => Items = value; } + CancellationToken IConnectionLifetimeFeature.ConnectionClosed + { + get => ConnectionClosed; + set => ConnectionClosed = value; + } + + void IConnectionLifetimeFeature.Abort() => Abort(); + + long IBytesWrittenFeature.TotalBytesWritten => TotalBytesWritten; + PipeScheduler ITransportSchedulerFeature.InputWriterScheduler => InputWriterScheduler; PipeScheduler ITransportSchedulerFeature.OutputReaderScheduler => OutputReaderScheduler; @@ -169,6 +186,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal return _currentITransportSchedulerFeature; } + if (key == IConnectionLifetimeFeatureType) + { + return _currentIConnectionLifetimeFeature; + } + + if (key == IBytesWrittenFeatureType) + { + return _currentIBytesWrittenFeature; + } + if (MaybeExtra != null) { return ExtraFeatureGet(key); @@ -208,6 +235,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal { _currentITransportSchedulerFeature = value; } + else if (key == IConnectionLifetimeFeatureType) + { + _currentIConnectionLifetimeFeature = value; + } + else if (key == IBytesWrittenFeatureType) + { + _currentIBytesWrittenFeature = value; + } else { ExtraFeatureSet(key, value); @@ -245,6 +280,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal { return (TFeature)_currentITransportSchedulerFeature; } + else if (typeof(TFeature) == typeof(IConnectionLifetimeFeature)) + { + return (TFeature)_currentIConnectionLifetimeFeature; + } + else if (typeof(TFeature) == typeof(IBytesWrittenFeature)) + { + return (TFeature)_currentIBytesWrittenFeature; + } else if (MaybeExtra != null) { return (TFeature)ExtraFeatureGet(typeof(TFeature)); @@ -285,6 +328,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal { _currentITransportSchedulerFeature = instance; } + else if (typeof(TFeature) == typeof(IConnectionLifetimeFeature)) + { + _currentIConnectionLifetimeFeature = instance; + } + else if (typeof(TFeature) == typeof(IBytesWrittenFeature)) + { + _currentIBytesWrittenFeature = instance; + } else { ExtraFeatureSet(typeof(TFeature), instance); @@ -332,6 +383,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal yield return new KeyValuePair(ITransportSchedulerFeatureType, _currentITransportSchedulerFeature); } + if (_currentIConnectionLifetimeFeature != null) + { + yield return new KeyValuePair(IConnectionLifetimeFeatureType, _currentIConnectionLifetimeFeature); + } + + if (_currentIBytesWrittenFeature != null) + { + yield return new KeyValuePair(IBytesWrittenFeatureType, _currentIBytesWrittenFeature); + } + if (MaybeExtra != null) { foreach (var item in MaybeExtra) diff --git a/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.cs b/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.cs index c3d014ea57..5791033da6 100644 --- a/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.cs +++ b/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.cs @@ -8,7 +8,7 @@ using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal { - public abstract partial class TransportConnection : ConnectionContext + public partial class TransportConnection : ConnectionContext { private IDictionary _items; @@ -21,6 +21,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal _currentIApplicationTransportFeature = this; _currentIMemoryPoolFeature = this; _currentITransportSchedulerFeature = this; + _currentIConnectionLifetimeFeature = this; + _currentIBytesWrittenFeature = this; } public IPAddress RemoteAddress { get; set; } @@ -35,6 +37,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal public virtual MemoryPool MemoryPool { get; } public virtual PipeScheduler InputWriterScheduler { get; } public virtual PipeScheduler OutputReaderScheduler { get; } + public virtual long TotalBytesWritten { get; } public override IDuplexPipe Transport { get; set; } public IDuplexPipe Application { get; set; } @@ -54,5 +57,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal public PipeWriter Input => Application.Output; public PipeReader Output => Application.Input; + + public CancellationToken ConnectionClosed { get; set; } + + public virtual void Abort() + { + } } } diff --git a/src/Kestrel.Transport.Libuv/Internal/LibuvConnection.cs b/src/Kestrel.Transport.Libuv/Internal/LibuvConnection.cs index 52c0fc6d03..4ba4e94152 100644 --- a/src/Kestrel.Transport.Libuv/Internal/LibuvConnection.cs +++ b/src/Kestrel.Transport.Libuv/Internal/LibuvConnection.cs @@ -25,6 +25,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal (handle, suggestedsize, state) => AllocCallback(handle, suggestedsize, state); private readonly UvStreamHandle _socket; + private readonly CancellationTokenSource _connectionClosedTokenSource = new CancellationTokenSource(); private MemoryHandle _bufferHandle; @@ -42,6 +43,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal LocalAddress = localEndPoint.Address; LocalPort = localEndPoint.Port; + + ConnectionClosed = _connectionClosedTokenSource.Token; } Log = log; @@ -55,6 +58,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal public override PipeScheduler InputWriterScheduler => Thread; public override PipeScheduler OutputReaderScheduler => Thread; + public override long TotalBytesWritten => OutputConsumer?.TotalBytesWritten ?? 0; + public async Task Start() { try @@ -91,6 +96,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal // We're done with the socket now _socket.Dispose(); + ThreadPool.QueueUserWorkItem(state => ((LibuvConnection)state).CancelConnectionClosedToken(), this); } } catch (Exception e) @@ -99,6 +105,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal } } + public override void Abort() + { + // This cancels any pending I/O. + Thread.Post(s => s.Dispose(), _socket); + } + // Called on Libuv thread private static LibuvFunctions.uv_buf_t AllocCallback(UvStreamHandle handle, int suggestedSize, object state) { @@ -205,5 +217,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal Input.Complete(error); } } + + private void CancelConnectionClosedToken() + { + try + { + _connectionClosedTokenSource.Cancel(); + _connectionClosedTokenSource.Dispose(); + } + catch (Exception ex) + { + Log.LogError(0, ex, $"Unexpected exception in {nameof(LibuvConnection)}.{nameof(CancelConnectionClosedToken)}."); + } + } } } diff --git a/src/Kestrel.Transport.Libuv/Internal/LibuvOutputConsumer.cs b/src/Kestrel.Transport.Libuv/Internal/LibuvOutputConsumer.cs index cb60e9a2e0..94195a3ffc 100644 --- a/src/Kestrel.Transport.Libuv/Internal/LibuvOutputConsumer.cs +++ b/src/Kestrel.Transport.Libuv/Internal/LibuvOutputConsumer.cs @@ -3,6 +3,7 @@ using System; using System.IO.Pipelines; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; @@ -16,6 +17,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal private readonly ILibuvTrace _log; private readonly PipeReader _pipe; + private long _totalBytesWritten; + public LibuvOutputConsumer( PipeReader pipe, LibuvThread thread, @@ -28,10 +31,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal _socket = socket; _connectionId = connectionId; _log = log; - - _pipe.OnWriterCompleted(OnWriterCompleted, this); } + public long TotalBytesWritten => Interlocked.Read(ref _totalBytesWritten); + public async Task WriteOutputAsync() { var pool = _thread.WriteReqPool; @@ -46,7 +49,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal } catch { - // Handled in OnWriterCompleted + // Handled in LibuvConnection.Abort() return; } @@ -73,6 +76,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal var writeResult = await writeReq.WriteAsync(_socket, buffer); + // This is not interlocked because there could be a concurrent writer. + // Instead it's to prevent read tearing on 32-bit systems. + Interlocked.Add(ref _totalBytesWritten, buffer.Length); + LogWriteInfo(writeResult.Status, writeResult.Error); if (writeResult.Error != null) @@ -85,6 +92,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal { // Make sure we return the writeReq to the pool pool.Return(writeReq); + + // Null out writeReq so it doesn't get caught by CheckUvReqLeaks. + // It is rooted by a TestSink scope through Pipe continuations in + // ResponseTests.HttpsConnectionClosedWhenResponseDoesNotSatisfyMinimumDataRate + writeReq = null; } } else if (result.IsCompleted) @@ -99,16 +111,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal } } - private static void OnWriterCompleted(Exception ex, object state) - { - // Cut off writes if the writer is completed with an error. If a write request is pending, this will cancel it. - if (ex != null) - { - var libuvOutputConsumer = (LibuvOutputConsumer)state; - libuvOutputConsumer._socket.Dispose(); - } - } - private void LogWriteInfo(int status, Exception error) { if (error == null) diff --git a/src/Kestrel.Transport.Sockets/Internal/SocketConnection.cs b/src/Kestrel.Transport.Sockets/Internal/SocketConnection.cs index 5582da73fc..59742c2e51 100644 --- a/src/Kestrel.Transport.Sockets/Internal/SocketConnection.cs +++ b/src/Kestrel.Transport.Sockets/Internal/SocketConnection.cs @@ -9,6 +9,7 @@ using System.IO.Pipelines; using System.Net; using System.Net.Sockets; using System.Runtime.InteropServices; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; @@ -26,8 +27,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal private readonly ISocketsTrace _trace; private readonly SocketReceiver _receiver; private readonly SocketSender _sender; + private readonly CancellationTokenSource _connectionClosedTokenSource = new CancellationTokenSource(); + private readonly object _shutdownLock = new object(); private volatile bool _aborted; + private long _totalBytesWritten; internal SocketConnection(Socket socket, MemoryPool memoryPool, PipeScheduler scheduler, ISocketsTrace trace) { @@ -49,6 +53,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal RemoteAddress = remoteEndPoint.Address; RemotePort = remoteEndPoint.Port; + ConnectionClosed = _connectionClosedTokenSource.Token; + // On *nix platforms, Sockets already dispatches to the ThreadPool. var awaiterScheduler = IsWindows ? _scheduler : PipeScheduler.Inline; @@ -59,6 +65,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal public override MemoryPool MemoryPool { get; } public override PipeScheduler InputWriterScheduler => _scheduler; public override PipeScheduler OutputReaderScheduler => _scheduler; + public override long TotalBytesWritten => Interlocked.Read(ref _totalBytesWritten); public async Task StartAsync() { @@ -86,6 +93,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal _socket.Dispose(); _receiver.Dispose(); _sender.Dispose(); + ThreadPool.QueueUserWorkItem(state => ((SocketConnection)state).CancelConnectionClosedToken(), this); } catch (Exception ex) { @@ -98,6 +106,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal } } + public override void Abort() + { + // Try to gracefully close the socket to match libuv behavior. + Shutdown(); + _socket.Dispose(); + } + private async Task DoReceive() { Exception error = null; @@ -214,15 +229,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal { error = new IOException(ex.Message, ex); } - finally - { - // Make sure to close the connection only after the _aborted flag is set. - // Without this, the RequestsCanBeAbortedMidRead test will sometimes fail when - // a BadHttpRequestException is thrown instead of a TaskCanceledException. - _aborted = true; - _trace.ConnectionWriteFin(ConnectionId); - _socket.Shutdown(SocketShutdown.Both); - } + + Shutdown(); return error; } @@ -247,6 +255,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal await _sender.SendAsync(buffer); } + // This is not interlocked because there could be a concurrent writer. + // Instead it's to prevent read tearing on 32-bit systems. + Interlocked.Add(ref _totalBytesWritten, buffer.Length); + Output.AdvanceTo(end); if (isCompleted) @@ -255,5 +267,36 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal } } } + + private void Shutdown() + { + lock (_shutdownLock) + { + if (!_aborted) + { + // Make sure to close the connection only after the _aborted flag is set. + // Without this, the RequestsCanBeAbortedMidRead test will sometimes fail when + // a BadHttpRequestException is thrown instead of a TaskCanceledException. + _aborted = true; + _trace.ConnectionWriteFin(ConnectionId); + + // Try to gracefully close the socket even for aborts to match libuv behavior. + _socket.Shutdown(SocketShutdown.Both); + } + } + } + + private void CancelConnectionClosedToken() + { + try + { + _connectionClosedTokenSource.Cancel(); + _connectionClosedTokenSource.Dispose(); + } + catch (Exception ex) + { + _trace.LogError(0, ex, $"Unexpected exception in {nameof(SocketConnection)}.{nameof(CancelConnectionClosedToken)}."); + } + } } } diff --git a/test/Kestrel.Core.Tests/ConnectionDispatcherTests.cs b/test/Kestrel.Core.Tests/ConnectionDispatcherTests.cs index 9f53e16750..24f8ded379 100644 --- a/test/Kestrel.Core.Tests/ConnectionDispatcherTests.cs +++ b/test/Kestrel.Core.Tests/ConnectionDispatcherTests.cs @@ -1,13 +1,10 @@ -using System.Buffers; +using System; +using System.Buffers; using System.Collections.Generic; using System.IO.Pipelines; using System.Linq; -using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.Http.Features; -using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; using Microsoft.AspNetCore.Testing; using Xunit; @@ -23,7 +20,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var tcs = new TaskCompletionSource(); var dispatcher = new ConnectionDispatcher(serviceContext, _ => tcs.Task); - var connection = new TestConnection(); + var connection = new TransportConnection(); dispatcher.OnConnection(connection); @@ -44,14 +41,5 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests // Verify the scope was disposed after request processing completed Assert.True(((TestKestrelTrace)serviceContext.Log).Logger.Scopes.IsEmpty); } - - private class TestConnection : TransportConnection - { - public override MemoryPool MemoryPool { get; } = KestrelMemoryPool.Create(); - - public override PipeScheduler InputWriterScheduler => PipeScheduler.ThreadPool; - - public override PipeScheduler OutputReaderScheduler => PipeScheduler.ThreadPool; - } } } diff --git a/test/Kestrel.Core.Tests/Http1ConnectionTests.cs b/test/Kestrel.Core.Tests/Http1ConnectionTests.cs index caf367d5a3..b1dc3b91aa 100644 --- a/test/Kestrel.Core.Tests/Http1ConnectionTests.cs +++ b/test/Kestrel.Core.Tests/Http1ConnectionTests.cs @@ -11,6 +11,7 @@ using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; @@ -48,12 +49,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests _transport = pair.Transport; _application = pair.Application; + var connectionFeatures = new FeatureCollection(); + connectionFeatures.Set(Mock.Of()); + connectionFeatures.Set(Mock.Of()); + _serviceContext = new TestServiceContext(); _timeoutControl = new Mock(); _http1ConnectionContext = new Http1ConnectionContext { ServiceContext = _serviceContext, - ConnectionFeatures = new FeatureCollection(), + ConnectionFeatures = connectionFeatures, MemoryPool = _pipelineFactory, TimeoutControl = _timeoutControl.Object, Application = pair.Application, @@ -727,8 +732,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests Assert.Equal(header0Count + header1Count, _http1Connection.RequestHeaders.Count); await _application.Output.WriteAsync(Encoding.ASCII.GetBytes("\r\n")); - Assert.Equal(header0Count + header1Count, _http1Connection.RequestHeaders.Count); - await requestProcessingTask.TimeoutAfter(TestConstants.DefaultTimeout); } @@ -767,9 +770,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests Assert.Equal(header0Count + header1Count, _http1Connection.RequestHeaders.Count); await _application.Output.WriteAsync(Encoding.ASCII.GetBytes("\r\n")); - Assert.Same(newRequestHeaders, _http1Connection.RequestHeaders); - Assert.Equal(header0Count + header1Count, _http1Connection.RequestHeaders.Count); - await requestProcessingTask.TimeoutAfter(TimeSpan.FromSeconds(10)); } diff --git a/test/Kestrel.Core.Tests/HttpConnectionTests.cs b/test/Kestrel.Core.Tests/HttpConnectionTests.cs index 8710055844..85243bd7ff 100644 --- a/test/Kestrel.Core.Tests/HttpConnectionTests.cs +++ b/test/Kestrel.Core.Tests/HttpConnectionTests.cs @@ -6,6 +6,8 @@ using System.Buffers; using System.Collections.Generic; using System.IO.Pipelines; using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; @@ -29,11 +31,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var options = new PipeOptions(_memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); var pair = DuplexPipe.CreateConnectionPair(options, options); + var connectionFeatures = new FeatureCollection(); + connectionFeatures.Set(Mock.Of()); + connectionFeatures.Set(Mock.Of()); + _httpConnectionContext = new HttpConnectionContext { ConnectionId = "0123456789", ConnectionAdapters = new List(), - ConnectionFeatures = new FeatureCollection(), + ConnectionFeatures = connectionFeatures, MemoryPool = _memoryPool, HttpConnectionId = long.MinValue, Application = pair.Application, @@ -531,5 +537,56 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests Assert.True(_httpConnection.RequestTimedOut); Assert.True(aborted.Wait(TimeSpan.FromSeconds(10))); } + + [Fact] + public async Task WriteTimingAbortsConnectionWhenRepeadtedSmallWritesDoNotCompleteWithMinimumDataRate() + { + var systemClock = new MockSystemClock(); + var minResponseDataRate = new MinDataRate(bytesPerSecond: 100, gracePeriod: TimeSpan.FromSeconds(5)); + var numWrites = 5; + var writeSize = 100; + var aborted = new TaskCompletionSource(); + + _httpConnectionContext.ServiceContext.ServerOptions.Limits.MinResponseDataRate = minResponseDataRate; + _httpConnectionContext.ServiceContext.SystemClock = systemClock; + + var mockLogger = new Mock(); + _httpConnectionContext.ServiceContext.Log = mockLogger.Object; + + _httpConnection.Initialize(_httpConnectionContext.Transport, _httpConnectionContext.Application); + _httpConnection.Http1Connection.Reset(); + _httpConnection.Http1Connection.RequestAborted.Register(() => + { + aborted.SetResult(null); + }); + + // Initialize timestamp + var startTime = systemClock.UtcNow; + _httpConnection.Tick(startTime); + + // 5 consecutive 100 byte writes. + for (var i = 0; i < numWrites - 1; i++) + { + _httpConnection.StartTimingWrite(writeSize); + _httpConnection.StopTimingWrite(); + } + + // Stall the last write. + _httpConnection.StartTimingWrite(writeSize); + + // Move the clock forward Heartbeat.Interval + MinDataRate.GracePeriod + 4 seconds. + // The grace period should only be added for the first write. The subsequent 4 100 byte writes should add 1 second each to the timeout given the 100 byte/s min rate. + systemClock.UtcNow += Heartbeat.Interval + minResponseDataRate.GracePeriod + TimeSpan.FromSeconds((numWrites - 1) * writeSize / minResponseDataRate.BytesPerSecond); + _httpConnection.Tick(systemClock.UtcNow); + + Assert.False(_httpConnection.RequestTimedOut); + + // On more tick forward triggers the timeout. + systemClock.UtcNow += TimeSpan.FromTicks(1); + _httpConnection.Tick(systemClock.UtcNow); + + Assert.True(_httpConnection.RequestTimedOut); + await aborted.Task.TimeoutAfter(TimeSpan.FromSeconds(10)); + } } } diff --git a/test/Kestrel.Core.Tests/MessageBodyTests.cs b/test/Kestrel.Core.Tests/MessageBodyTests.cs index 9dccccba40..b92f429400 100644 --- a/test/Kestrel.Core.Tests/MessageBodyTests.cs +++ b/test/Kestrel.Core.Tests/MessageBodyTests.cs @@ -786,7 +786,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests input.Fin(); - await Assert.ThrowsAsync(async () => await body.ReadAsync(new ArraySegment(new byte[1]))); + Assert.Equal(0, await body.ReadAsync(new ArraySegment(new byte[1]))); mockTimeoutControl.Verify(timeoutControl => timeoutControl.StartTimingReads(), Times.Never); mockTimeoutControl.Verify(timeoutControl => timeoutControl.StopTimingReads(), Times.Never); diff --git a/test/Kestrel.Core.Tests/OutputProducerTests.cs b/test/Kestrel.Core.Tests/OutputProducerTests.cs index dc063922ff..3f9750a11e 100644 --- a/test/Kestrel.Core.Tests/OutputProducerTests.cs +++ b/test/Kestrel.Core.Tests/OutputProducerTests.cs @@ -5,6 +5,7 @@ using System; using System.Buffers; using System.IO.Pipelines; using System.Threading; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; @@ -49,6 +50,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests socketOutput.Write((buffer, state) => { called = true; + return 0; }, 0); @@ -56,8 +58,33 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests } } - private Http1OutputProducer CreateOutputProducer(PipeOptions pipeOptions) + [Fact] + public void AbortsTransportEvenAfterDispose() { + var mockLifetimeFeature = new Mock(); + + var outputProducer = CreateOutputProducer(lifetimeFeature: mockLifetimeFeature.Object); + + outputProducer.Dispose(); + + mockLifetimeFeature.Verify(f => f.Abort(), Times.Never()); + + outputProducer.Abort(null); + + mockLifetimeFeature.Verify(f => f.Abort(), Times.Once()); + + outputProducer.Abort(null); + + mockLifetimeFeature.Verify(f => f.Abort(), Times.Once()); + } + + private Http1OutputProducer CreateOutputProducer( + PipeOptions pipeOptions = null, + IConnectionLifetimeFeature lifetimeFeature = null) + { + pipeOptions = pipeOptions ?? new PipeOptions(); + lifetimeFeature = lifetimeFeature ?? Mock.Of(); + var pipe = new Pipe(pipeOptions); var serviceContext = new TestServiceContext(); var socketOutput = new Http1OutputProducer( @@ -65,7 +92,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests pipe.Writer, "0", serviceContext.Log, - Mock.Of()); + Mock.Of(), + lifetimeFeature, + Mock.Of()); return socketOutput; } diff --git a/test/Kestrel.Core.Tests/PipelineExtensionTests.cs b/test/Kestrel.Core.Tests/PipelineExtensionTests.cs index 11f7209ac2..2e09a3d2cf 100644 --- a/test/Kestrel.Core.Tests/PipelineExtensionTests.cs +++ b/test/Kestrel.Core.Tests/PipelineExtensionTests.cs @@ -38,7 +38,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests public void WritesNumericToAscii(ulong number) { var writerBuffer = _pipe.Writer; - var writer = new BufferWriter(writerBuffer); + var writer = new CountingBufferWriter(writerBuffer); writer.WriteNumeric(number); writer.Commit(); writerBuffer.FlushAsync().GetAwaiter().GetResult(); @@ -56,7 +56,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests public void WritesNumericAcrossSpanBoundaries(int gapSize) { var writerBuffer = _pipe.Writer; - var writer = new BufferWriter(writerBuffer); + var writer = new CountingBufferWriter(writerBuffer); // almost fill up the first block var spacer = new byte[writer.Span.Length - gapSize]; writer.Write(spacer); @@ -87,7 +87,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests public void EncodesAsAscii(string input, byte[] expected) { var pipeWriter = _pipe.Writer; - var writer = new BufferWriter(pipeWriter); + var writer = new CountingBufferWriter(pipeWriter); writer.WriteAsciiNoValidation(input); writer.Commit(); pipeWriter.FlushAsync().GetAwaiter().GetResult(); @@ -117,7 +117,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests // WriteAscii doesn't validate if characters are in the ASCII range // but it shouldn't produce more than one byte per character var writerBuffer = _pipe.Writer; - var writer = new BufferWriter(writerBuffer); + var writer = new CountingBufferWriter(writerBuffer); writer.WriteAsciiNoValidation(input); writer.Commit(); writerBuffer.FlushAsync().GetAwaiter().GetResult(); @@ -131,7 +131,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests { const byte maxAscii = 0x7f; var writerBuffer = _pipe.Writer; - var writer = new BufferWriter(writerBuffer); + var writer = new CountingBufferWriter(writerBuffer); for (var i = 0; i < maxAscii; i++) { writer.WriteAsciiNoValidation(new string((char)i, 1)); @@ -161,7 +161,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests { var testString = new string(' ', stringLength); var writerBuffer = _pipe.Writer; - var writer = new BufferWriter(writerBuffer); + var writer = new CountingBufferWriter(writerBuffer); // almost fill up the first block var spacer = new byte[writer.Span.Length - gapSize]; writer.Write(spacer); diff --git a/test/Kestrel.Core.Tests/TestInput.cs b/test/Kestrel.Core.Tests/TestInput.cs index 1916b62753..faa40411c2 100644 --- a/test/Kestrel.Core.Tests/TestInput.cs +++ b/test/Kestrel.Core.Tests/TestInput.cs @@ -6,6 +6,7 @@ using System.Buffers; using System.IO.Pipelines; using System.Text; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; @@ -27,10 +28,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests Transport = pair.Transport; Application = pair.Application; + var connectionFeatures = new FeatureCollection(); + connectionFeatures.Set(Mock.Of()); + connectionFeatures.Set(Mock.Of()); + Http1ConnectionContext = new Http1ConnectionContext { ServiceContext = new TestServiceContext(), - ConnectionFeatures = new FeatureCollection(), + ConnectionFeatures = connectionFeatures, Application = Application, Transport = Transport, MemoryPool = _memoryPool, diff --git a/test/Kestrel.FunctionalTests/HttpsTests.cs b/test/Kestrel.FunctionalTests/HttpsTests.cs index 24cc86ef3a..e113baab82 100644 --- a/test/Kestrel.FunctionalTests/HttpsTests.cs +++ b/test/Kestrel.FunctionalTests/HttpsTests.cs @@ -234,7 +234,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"); await sslStream.WriteAsync(request, 0, request.Length); - await sslStream.ReadAsync(new byte[32], 0, 32); + + // Temporary workaround for a deadlock when reading from an aborted client SslStream on Mac and Linux. + if (TestPlatformHelper.IsWindows) + { + await sslStream.ReadAsync(new byte[32], 0, 32); + } + else + { + await stream.ReadAsync(new byte[32], 0, 32); + } } } @@ -285,7 +294,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"); await sslStream.WriteAsync(request, 0, request.Length); - await sslStream.ReadAsync(new byte[32], 0, 32); + + // Temporary workaround for a deadlock when reading from an aborted client SslStream on Mac and Linux. + if (TestPlatformHelper.IsWindows) + { + await sslStream.ReadAsync(new byte[32], 0, 32); + } + else + { + await stream.ReadAsync(new byte[32], 0, 32); + } } } diff --git a/test/Kestrel.FunctionalTests/RequestTests.cs b/test/Kestrel.FunctionalTests/RequestTests.cs index 3b6c2fc0d5..00cb93198b 100644 --- a/test/Kestrel.FunctionalTests/RequestTests.cs +++ b/test/Kestrel.FunctionalTests/RequestTests.cs @@ -15,10 +15,11 @@ using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; -using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; @@ -571,6 +572,108 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ConnectionClosedTokenFiresOnClientFIN(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + var appStartedTcs = new TaskCompletionSource(); + var connectionClosedTcs = new TaskCompletionSource(); + + using (var server = new TestServer(context => + { + appStartedTcs.SetResult(null); + + var connectionLifetimeFeature = context.Features.Get(); + connectionLifetimeFeature.ConnectionClosed.Register(() => connectionClosedTcs.SetResult(null)); + + return Task.CompletedTask; + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + await appStartedTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout); + + connection.Socket.Shutdown(SocketShutdown.Send); + + await connectionClosedTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ConnectionClosedTokenFiresOnServerFIN(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + var connectionClosedTcs = new TaskCompletionSource(); + + using (var server = new TestServer(context => + { + var connectionLifetimeFeature = context.Features.Get(); + connectionLifetimeFeature.ConnectionClosed.Register(() => connectionClosedTcs.SetResult(null)); + + return Task.CompletedTask; + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Connection: close", + "", + ""); + + await connectionClosedTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout); + + await connection.ReceiveEnd($"HTTP/1.1 200 OK", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ConnectionClosedTokenFiresOnServerAbort(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + var connectionClosedTcs = new TaskCompletionSource(); + + using (var server = new TestServer(context => + { + var connectionLifetimeFeature = context.Features.Get(); + connectionLifetimeFeature.ConnectionClosed.Register(() => connectionClosedTcs.SetResult(null)); + + context.Abort(); + + return Task.CompletedTask; + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + await connectionClosedTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout); + await connection.ReceiveForcedEnd(); + } + } + } + [Theory] [InlineData("http://localhost/abs/path", "/abs/path", null)] [InlineData("https://localhost/abs/path", "/abs/path", null)] // handles mismatch scheme diff --git a/test/Kestrel.FunctionalTests/ResponseTests.cs b/test/Kestrel.FunctionalTests/ResponseTests.cs index 9e228140cc..0f2dc8a9cd 100644 --- a/test/Kestrel.FunctionalTests/ResponseTests.cs +++ b/test/Kestrel.FunctionalTests/ResponseTests.cs @@ -21,7 +21,6 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Https; @@ -1704,6 +1703,50 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests Assert.True(foundMessage, "Expected log not found"); } + [Fact] + public async Task Sending100ContinueDoesNotPreventAutomatic400Responses() + { + using (var server = new TestServer(httpContext => + { + return httpContext.Request.Body.ReadAsync(new byte[1], 0, 1); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "Expect: 100-continue", + "", + ""); + + await connection.Receive( + "HTTP/1.1 100 Continue", + "", + ""); + + // Send an invalid chunk prefix to cause an error. + await connection.Send( + "gg"); + + // If 100 Continue sets HttpProtocol.HasResponseStarted to true, + // a success response will be produced before the server sees the + // bad chunk header above, making this test fail. + await connection.ReceiveForcedEnd( + "HTTP/1.1 400 Bad Request", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.Contains(TestApplicationErrorLogger.Messages, w => w.EventId.Id == 17 && w.LogLevel == LogLevel.Information && w.Exception is BadHttpRequestException + && ((BadHttpRequestException)w.Exception).StatusCode == StatusCodes.Status400BadRequest); + } + [Fact] public async Task Sending100ContinueAndResponseSendsChunkTerminatorBeforeConsumingRequestBody() { @@ -2587,21 +2630,27 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } [Fact] - public void ConnectionClosedWhenResponseDoesNotSatisfyMinimumDataRate() + public async Task ConnectionClosedWhenResponseDoesNotSatisfyMinimumDataRate() { using (StartLog(out var loggerFactory, "ConnClosedWhenRespDoesNotSatisfyMin")) { var logger = loggerFactory.CreateLogger($"{ typeof(ResponseTests).FullName}.{ nameof(ConnectionClosedWhenResponseDoesNotSatisfyMinimumDataRate)}"); - var chunkSize = 64 * 1024; - var chunks = 128; + const int chunkSize = 1024; + const int chunks = 256 * 1024; var responseSize = chunks * chunkSize; + var chunkData = new byte[chunkSize]; - var requestAborted = new ManualResetEventSlim(); - var messageLogged = new ManualResetEventSlim(); - var mockKestrelTrace = new Mock(loggerFactory.CreateLogger("Microsoft.AspNetCore.Server.Kestrel")) { CallBase = true }; + var responseRateTimeoutMessageLogged = new TaskCompletionSource(); + var connectionStopMessageLogged = new TaskCompletionSource(); + var requestAborted = new TaskCompletionSource(); + + var mockKestrelTrace = new Mock(); mockKestrelTrace .Setup(trace => trace.ResponseMininumDataRateNotSatisfied(It.IsAny(), It.IsAny())) - .Callback(() => messageLogged.Set()); + .Callback(() => responseRateTimeoutMessageLogged.SetResult(null)); + mockKestrelTrace + .Setup(trace => trace.ConnectionStop(It.IsAny())) + .Callback(() => connectionStopMessageLogged.SetResult(null));; var testContext = new TestServiceContext { @@ -2624,57 +2673,41 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests async Task App(HttpContext context) { appLogger.LogInformation("Request received"); - context.RequestAborted.Register(() => requestAborted.Set()); + context.RequestAborted.Register(() => requestAborted.SetResult(null)); context.Response.ContentLength = responseSize; for (var i = 0; i < chunks; i++) { - await context.Response.WriteAsync(new string('a', chunkSize), context.RequestAborted); + await context.Response.Body.WriteAsync(chunkData, 0, chunkData.Length, context.RequestAborted); appLogger.LogInformation("Wrote chunk of {chunkSize} bytes", chunkSize); } } using (var server = new TestServer(App, testContext, listenOptions)) { - using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + using (var connection = server.CreateConnection()) { - socket.ReceiveBufferSize = 1; - - socket.Connect(new IPEndPoint(IPAddress.Loopback, server.Port)); logger.LogInformation("Sending request"); - socket.Send(Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost: \r\n\r\n")); + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + logger.LogInformation("Sent request"); var sw = Stopwatch.StartNew(); logger.LogInformation("Waiting for connection to abort."); - Assert.True(messageLogged.Wait(TimeSpan.FromSeconds(120)), "The expected message was not logged within the timeout period."); - Assert.True(requestAborted.Wait(TimeSpan.FromSeconds(120)), "The request was not aborted within the timeout period."); + + await requestAborted.Task.TimeoutAfter(TimeSpan.FromSeconds(60)); + await responseRateTimeoutMessageLogged.Task.TimeoutAfter(TimeSpan.FromSeconds(10)); + await connectionStopMessageLogged.Task.TimeoutAfter(TimeSpan.FromSeconds(10)); + + await AssertStreamAborted(connection.Reader.BaseStream, chunkSize * chunks); + sw.Stop(); logger.LogInformation("Connection was aborted after {totalMilliseconds}ms.", sw.ElapsedMilliseconds); - - var totalReceived = 0; - var received = 0; - - try - { - var buffer = new byte[chunkSize]; - - do - { - received = socket.Receive(buffer); - totalReceived += received; - } while (received > 0 && totalReceived < responseSize); - } - catch (SocketException) { } - catch (IOException) - { - // Socket.Receive could throw, and that is fine - } - - // Since we expect writes to be cut off by the rate control, we should never see the entire response - logger.LogInformation("Received {totalReceived} bytes", totalReceived); - Assert.NotEqual(responseSize, totalReceived); } } } @@ -2683,18 +2716,23 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests [Fact] public async Task HttpsConnectionClosedWhenResponseDoesNotSatisfyMinimumDataRate() { - const int chunkSize = 64 * 1024; - const int chunks = 128; + const int chunkSize = 1024; + const int chunks = 256 * 1024; + var chunkData = new byte[chunkSize]; var certificate = new X509Certificate2(TestResources.TestCertificatePath, "testPassword"); - var messageLogged = new ManualResetEventSlim(); - var aborted = new ManualResetEventSlim(); + var responseRateTimeoutMessageLogged = new TaskCompletionSource(); + var connectionStopMessageLogged = new TaskCompletionSource(); + var aborted = new TaskCompletionSource(); var mockKestrelTrace = new Mock(); mockKestrelTrace .Setup(trace => trace.ResponseMininumDataRateNotSatisfied(It.IsAny(), It.IsAny())) - .Callback(() => messageLogged.Set()); + .Callback(() => responseRateTimeoutMessageLogged.SetResult(null)); + mockKestrelTrace + .Setup(trace => trace.ConnectionStop(It.IsAny())) + .Callback(() => connectionStopMessageLogged.SetResult(null)); var testContext = new TestServiceContext(LoggerFactory, mockKestrelTrace.Object) { @@ -2720,41 +2758,183 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { context.RequestAborted.Register(() => { - aborted.Set(); + aborted.SetResult(null); }); context.Response.ContentLength = chunks * chunkSize; for (var i = 0; i < chunks; i++) { - await context.Response.WriteAsync(new string('a', chunkSize), context.RequestAborted); + await context.Response.Body.WriteAsync(chunkData, 0, chunkData.Length, context.RequestAborted); } }, testContext, listenOptions)) { - using (var client = new TcpClient()) + using (var connection = server.CreateConnection()) { - await client.ConnectAsync(IPAddress.Loopback, server.Port); - - using (var sslStream = new SslStream(client.GetStream(), false, (sender, cert, chain, errors) => true, null)) + using (var sslStream = new SslStream(connection.Reader.BaseStream, false, (sender, cert, chain, errors) => true, null)) { await sslStream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false); var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"); await sslStream.WriteAsync(request, 0, request.Length); - Assert.True(aborted.Wait(TimeSpan.FromSeconds(60))); + await aborted.Task.TimeoutAfter(TimeSpan.FromSeconds(60)); + await responseRateTimeoutMessageLogged.Task.TimeoutAfter(TimeSpan.FromSeconds(10)); + await connectionStopMessageLogged.Task.TimeoutAfter(TimeSpan.FromSeconds(10)); - using (var reader = new StreamReader(sslStream, encoding: Encoding.ASCII, detectEncodingFromByteOrderMarks: false, bufferSize: 1024, leaveOpen: false)) + // Temporary workaround for a deadlock when reading from an aborted client SslStream on Mac and Linux. + if (TestPlatformHelper.IsWindows) { - await reader.ReadToEndAsync().TimeoutAfter(TestConstants.DefaultTimeout); + await AssertStreamAborted(sslStream, chunkSize * chunks); + } + else + { + await AssertStreamAborted(connection.Reader.BaseStream, chunkSize * chunks); } - - Assert.True(messageLogged.Wait(TestConstants.DefaultTimeout)); } } } } + [Fact] + public async Task ConnectionNotClosedWhenClientSatisfiesMinimumDataRateGivenLargeResponseChunks() + { + var chunkSize = 64 * 128 * 1024; + var chunkCount = 4; + var chunkData = new byte[chunkSize]; + + var requestAborted = false; + var mockKestrelTrace = new Mock(); + + var testContext = new TestServiceContext + { + Log = mockKestrelTrace.Object, + SystemClock = new SystemClock(), + ServerOptions = + { + Limits = + { + MinResponseDataRate = new MinDataRate(bytesPerSecond: 240, gracePeriod: TimeSpan.FromSeconds(2)) + } + } + }; + + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + + async Task App(HttpContext context) + { + context.RequestAborted.Register(() => + { + requestAborted = true; + }); + + for (var i = 0; i < chunkCount; i++) + { + await context.Response.Body.WriteAsync(chunkData, 0, chunkData.Length, context.RequestAborted); + } + } + + using (var server = new TestServer(App, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + // Close the connection with the last request so AssertStreamCompleted actually completes. + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Connection: close", + "", + ""); + + var minTotalOutputSize = chunkCount * chunkSize; + + // Make sure consuming a single chunk exceeds the 2 second timeout. + var targetBytesPerSecond = chunkSize / 4; + await AssertStreamCompleted(connection.Reader.BaseStream, minTotalOutputSize, targetBytesPerSecond); + + mockKestrelTrace.Verify(t => t.ResponseMininumDataRateNotSatisfied(It.IsAny(), It.IsAny()), Times.Never()); + mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny()), Times.Once()); + Assert.False(requestAborted); + } + } + } + + [Fact] + public async Task ConnectionNotClosedWhenClientSatisfiesMinimumDataRateGivenLargeResponseHeaders() + { + var headerSize = 1024 * 1024; // 1 MB for each header value + var headerCount = 64; // 64 MB of headers per response + var requestCount = 4; // Minimum of 256 MB of total response headers + var headerValue = new string('a', headerSize); + var headerStringValues = new StringValues(Enumerable.Repeat(headerValue, headerCount).ToArray()); + + var requestAborted = false; + var mockKestrelTrace = new Mock(); + + var testContext = new TestServiceContext + { + Log = mockKestrelTrace.Object, + SystemClock = new SystemClock(), + ServerOptions = + { + Limits = + { + MinResponseDataRate = new MinDataRate(bytesPerSecond: 240, gracePeriod: TimeSpan.FromSeconds(2)) + } + } + }; + + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + + async Task App(HttpContext context) + { + context.RequestAborted.Register(() => + { + requestAborted = true; + }); + + context.Response.Headers[$"X-Custom-Header"] = headerStringValues; + context.Response.ContentLength = 0; + + await context.Response.Body.FlushAsync(); + } + + using (var server = new TestServer(App, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + for (var i = 0; i < requestCount - 1; i++) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + } + + // Close the connection with the last request so AssertStreamCompleted actually completes. + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Connection: close", + "", + ""); + + var responseSize = headerSize * headerCount; + var minTotalOutputSize = requestCount * responseSize; + + // Make sure consuming a single set of response headers exceeds the 2 second timeout. + var targetBytesPerSecond = responseSize / 4; + await AssertStreamCompleted(connection.Reader.BaseStream, minTotalOutputSize, targetBytesPerSecond); + + mockKestrelTrace.Verify(t => t.ResponseMininumDataRateNotSatisfied(It.IsAny(), It.IsAny()), Times.Never()); + mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny()), Times.Once()); + Assert.False(requestAborted); + } + } + } + + [Fact] public async Task NonZeroContentLengthFor304StatusCodeIsAllowed() { @@ -2784,6 +2964,56 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } + private async Task AssertStreamAborted(Stream stream, int totalBytes) + { + var receiveBuffer = new byte[64 * 1024]; + var totalReceived = 0; + + try + { + while (totalReceived < totalBytes) + { + var bytes = await stream.ReadAsync(receiveBuffer, 0, receiveBuffer.Length).TimeoutAfter(TimeSpan.FromSeconds(10)); + + if (bytes == 0) + { + break; + } + + totalReceived += bytes; + } + } + catch (IOException) + { + // This is expected given an abort. + } + + Assert.True(totalReceived < totalBytes, $"{nameof(AssertStreamAborted)} Stream completed successfully."); + } + + private async Task AssertStreamCompleted(Stream stream, long minimumBytes, int targetBytesPerSecond) + { + var receiveBuffer = new byte[64 * 1024]; + var received = 0; + var totalReceived = 0; + var startTime = DateTimeOffset.UtcNow; + + do + { + received = await stream.ReadAsync(receiveBuffer, 0, receiveBuffer.Length); + totalReceived += received; + + var expectedTimeElapsed = TimeSpan.FromSeconds(totalReceived / targetBytesPerSecond); + var timeElapsed = DateTimeOffset.UtcNow - startTime; + if (timeElapsed < expectedTimeElapsed) + { + await Task.Delay(expectedTimeElapsed - timeElapsed); + } + } while (received > 0); + + Assert.True(totalReceived >= minimumBytes, $"{nameof(AssertStreamCompleted)} Stream aborted prematurely."); + } + public static TheoryData NullHeaderData { get diff --git a/test/Kestrel.FunctionalTests/UpgradeTests.cs b/test/Kestrel.FunctionalTests/UpgradeTests.cs index 8ff867966d..d7a0f5073f 100644 --- a/test/Kestrel.FunctionalTests/UpgradeTests.cs +++ b/test/Kestrel.FunctionalTests/UpgradeTests.cs @@ -297,5 +297,42 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests var exception = await Assert.ThrowsAsync(async () => await upgradeTcs.Task.TimeoutAfter(TimeSpan.FromSeconds(60))); Assert.Equal(CoreStrings.UpgradedConnectionLimitReached, exception.Message); } + + [Fact] + public async Task DoesNotThrowOnFin() + { + var appCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(async context => + { + var feature = context.Features.Get(); + var duplexStream = await feature.UpgradeAsync(); + + try + { + await duplexStream.CopyToAsync(Stream.Null); + appCompletedTcs.SetResult(null); + } + catch (Exception ex) + { + appCompletedTcs.SetException(ex); + throw; + } + + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGetWithUpgrade(); + await connection.Receive("HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {server.Context.DateHeaderValue}", + "", + ""); + } + + await appCompletedTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout); + } + } } } diff --git a/test/Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs b/test/Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs index ef024facfb..e269bda59c 100644 --- a/test/Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs +++ b/test/Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs @@ -7,6 +7,7 @@ using System.Collections.Concurrent; using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; @@ -304,6 +305,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests outputProducer.Write((writableBuffer, state) => { writableBuffer.Write(state); + return state.Count; }, halfWriteBehindBuffer); @@ -729,10 +731,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests var socket = new MockSocket(_mockLibuv, _libuvThread.Loop.ThreadId, transportContext.Log); var consumer = new LibuvOutputConsumer(pair.Application.Input, _libuvThread, socket, "0", transportContext.Log); + var connectionFeatures = new FeatureCollection(); + connectionFeatures.Set(Mock.Of()); + connectionFeatures.Set(Mock.Of()); + var http1Connection = new Http1Connection(new Http1ConnectionContext { ServiceContext = serviceContext, - ConnectionFeatures = new FeatureCollection(), + ConnectionFeatures = connectionFeatures, MemoryPool = _memoryPool, TimeoutControl = Mock.Of(), Application = pair.Application, diff --git a/tools/CodeGenerator/KnownHeaders.cs b/tools/CodeGenerator/KnownHeaders.cs index 08646dc61a..badf9087a7 100644 --- a/tools/CodeGenerator/KnownHeaders.cs +++ b/tools/CodeGenerator/KnownHeaders.cs @@ -548,7 +548,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return true; }} {(loop.ClassName == "HttpResponseHeaders" ? $@" - internal void CopyToFast(ref BufferWriter output) + internal void CopyToFast(ref CountingBufferWriter output) {{ var tempBits = _bits | (_contentLength.HasValue ? {1L << 63}L : 0); {Each(loop.Headers.Where(header => header.Identifier != "ContentLength").OrderBy(h => !h.PrimaryHeader), header => $@"