diff --git a/src/Servers/Connections.Abstractions/src/ConnectionBuilder.cs b/src/Servers/Connections.Abstractions/src/ConnectionBuilder.cs new file mode 100644 index 0000000000..b75e92b60f --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/ConnectionBuilder.cs @@ -0,0 +1,44 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Connections +{ + public class ConnectionBuilder : IConnectionBuilder + { + private readonly IList> _components = new List>(); + + public IServiceProvider ApplicationServices { get; } + + public ConnectionBuilder(IServiceProvider applicationServices) + { + ApplicationServices = applicationServices; + } + + public IConnectionBuilder Use(Func middleware) + { + _components.Add(middleware); + return this; + } + + public ConnectionDelegate Build() + { + ConnectionDelegate app = features => + { + return Task.CompletedTask; + }; + + foreach (var component in _components.Reverse()) + { + app = component(app); + } + + return app; + } + } +} \ No newline at end of file diff --git a/src/Servers/Connections.Abstractions/src/ConnectionBuilderExtensions.cs b/src/Servers/Connections.Abstractions/src/ConnectionBuilderExtensions.cs new file mode 100644 index 0000000000..100917b009 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/ConnectionBuilderExtensions.cs @@ -0,0 +1,43 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.Extensions.Internal; + +namespace Microsoft.AspNetCore.Connections +{ + public static class ConnectionBuilderExtensions + { + public static IConnectionBuilder UseConnectionHandler(this IConnectionBuilder connectionBuilder) where TConnectionHandler : ConnectionHandler + { + var handler = ActivatorUtilities.GetServiceOrCreateInstance(connectionBuilder.ApplicationServices); + + // This is a terminal middleware, so there's no need to use the 'next' parameter + return connectionBuilder.Run(connection => handler.OnConnectedAsync(connection)); + } + + public static IConnectionBuilder Use(this IConnectionBuilder connectionBuilder, Func, Task> middleware) + { + return connectionBuilder.Use(next => + { + return context => + { + Func simpleNext = () => next(context); + return middleware(context, simpleNext); + }; + }); + } + + public static IConnectionBuilder Run(this IConnectionBuilder connectionBuilder, Func middleware) + { + return connectionBuilder.Use(next => + { + return context => + { + return middleware(context); + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Servers/Connections.Abstractions/src/ConnectionContext.cs b/src/Servers/Connections.Abstractions/src/ConnectionContext.cs new file mode 100644 index 0000000000..bb942f2627 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/ConnectionContext.cs @@ -0,0 +1,31 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.IO.Pipelines; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Connections +{ + public abstract class ConnectionContext + { + public abstract string ConnectionId { get; set; } + + public abstract IFeatureCollection Features { get; } + + public abstract IDictionary Items { get; set; } + + public abstract IDuplexPipe Transport { get; set; } + + public virtual void Abort(ConnectionAbortedException abortReason) + { + // We expect this to be overridden, but this helps maintain back compat + // with implementations of ConnectionContext that predate the addition of + // ConnectioContext.Abort() + Features.Get()?.Abort(); + } + + public virtual void Abort() => Abort(new ConnectionAbortedException("The connection was aborted by the application.")); + } +} diff --git a/src/Servers/Connections.Abstractions/src/ConnectionDelegate.cs b/src/Servers/Connections.Abstractions/src/ConnectionDelegate.cs new file mode 100644 index 0000000000..f0d64d1587 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/ConnectionDelegate.cs @@ -0,0 +1,6 @@ +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Connections +{ + public delegate Task ConnectionDelegate(ConnectionContext connection); +} diff --git a/src/Servers/Connections.Abstractions/src/ConnectionHandler.cs b/src/Servers/Connections.Abstractions/src/ConnectionHandler.cs new file mode 100644 index 0000000000..e9e208d61a --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/ConnectionHandler.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Connections +{ + /// + /// Represents an end point that multiple connections connect to. For HTTP, endpoints are URLs, for non HTTP it can be a TCP listener (or similar) + /// + public abstract class ConnectionHandler + { + /// + /// Called when a new connection is accepted to the endpoint + /// + /// The new + /// A that represents the connection lifetime. When the task completes, the connection is complete. + public abstract Task OnConnectedAsync(ConnectionContext connection); + } +} \ No newline at end of file diff --git a/src/Servers/Connections.Abstractions/src/ConnectionItems.cs b/src/Servers/Connections.Abstractions/src/ConnectionItems.cs new file mode 100644 index 0000000000..0f01a62111 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/ConnectionItems.cs @@ -0,0 +1,119 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Connections +{ + public class ConnectionItems : IDictionary + { + public ConnectionItems() + : this(new Dictionary()) + { + } + + public ConnectionItems(IDictionary items) + { + Items = items; + } + + public IDictionary Items { get; } + + // Replace the indexer with one that returns null for missing values + object IDictionary.this[object key] + { + get + { + if (Items.TryGetValue(key, out var value)) + { + return value; + } + return null; + } + set { Items[key] = value; } + } + + void IDictionary.Add(object key, object value) + { + Items.Add(key, value); + } + + bool IDictionary.ContainsKey(object key) + { + return Items.ContainsKey(key); + } + + ICollection IDictionary.Keys + { + get { return Items.Keys; } + } + + bool IDictionary.Remove(object key) + { + return Items.Remove(key); + } + + bool IDictionary.TryGetValue(object key, out object value) + { + return Items.TryGetValue(key, out value); + } + + ICollection IDictionary.Values + { + get { return Items.Values; } + } + + void ICollection>.Add(KeyValuePair item) + { + Items.Add(item); + } + + void ICollection>.Clear() + { + Items.Clear(); + } + + bool ICollection>.Contains(KeyValuePair item) + { + return Items.Contains(item); + } + + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) + { + Items.CopyTo(array, arrayIndex); + } + + int ICollection>.Count + { + get { return Items.Count; } + } + + bool ICollection>.IsReadOnly + { + get { return Items.IsReadOnly; } + } + + bool ICollection>.Remove(KeyValuePair item) + { + object value; + if (Items.TryGetValue(item.Key, out value) && Equals(item.Value, value)) + { + return Items.Remove(item.Key); + } + return false; + } + + IEnumerator> IEnumerable>.GetEnumerator() + { + return Items.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return Items.GetEnumerator(); + } + } +} diff --git a/src/Servers/Connections.Abstractions/src/DefaultConnectionContext.cs b/src/Servers/Connections.Abstractions/src/DefaultConnectionContext.cs new file mode 100644 index 0000000000..56c5d0424f --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/DefaultConnectionContext.cs @@ -0,0 +1,78 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Security.Claims; +using System.Threading; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Connections +{ + public class DefaultConnectionContext : ConnectionContext, + IDisposable, + IConnectionIdFeature, + IConnectionItemsFeature, + IConnectionTransportFeature, + IConnectionUserFeature, + IConnectionLifetimeFeature + { + private CancellationTokenSource _connectionClosedTokenSource = new CancellationTokenSource(); + + public DefaultConnectionContext() : + this(Guid.NewGuid().ToString()) + { + ConnectionClosed = _connectionClosedTokenSource.Token; + } + + /// + /// Creates the DefaultConnectionContext without Pipes to avoid upfront allocations. + /// The caller is expected to set the and pipes manually. + /// + /// + public DefaultConnectionContext(string id) + { + ConnectionId = id; + + Features = new FeatureCollection(); + Features.Set(this); + Features.Set(this); + Features.Set(this); + Features.Set(this); + Features.Set(this); + } + + public DefaultConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application) + : this(id) + { + Transport = transport; + Application = application; + } + + public override string ConnectionId { get; set; } + + public override IFeatureCollection Features { get; } + + public ClaimsPrincipal User { get; set; } + + public override IDictionary Items { get; set; } = new ConnectionItems(); + + public IDuplexPipe Application { get; set; } + + public override IDuplexPipe Transport { get; set; } + + public CancellationToken ConnectionClosed { get; set; } + + public override void Abort(ConnectionAbortedException abortReason) + { + ThreadPool.QueueUserWorkItem(cts => ((CancellationTokenSource)cts).Cancel(), _connectionClosedTokenSource); + } + + public void Dispose() + { + _connectionClosedTokenSource.Dispose(); + } + } +} diff --git a/src/Servers/Connections.Abstractions/src/Exceptions/AddressInUseException.cs b/src/Servers/Connections.Abstractions/src/Exceptions/AddressInUseException.cs new file mode 100644 index 0000000000..817abe8998 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Exceptions/AddressInUseException.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Connections +{ + public class AddressInUseException : InvalidOperationException + { + public AddressInUseException(string message) : base(message) + { + } + + public AddressInUseException(string message, Exception inner) : base(message, inner) + { + } + } +} diff --git a/src/Servers/Connections.Abstractions/src/Exceptions/ConnectionAbortedException.cs b/src/Servers/Connections.Abstractions/src/Exceptions/ConnectionAbortedException.cs new file mode 100644 index 0000000000..7615010cc7 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Exceptions/ConnectionAbortedException.cs @@ -0,0 +1,21 @@ +using System; + +namespace Microsoft.AspNetCore.Connections +{ + public class ConnectionAbortedException : OperationCanceledException + { + public ConnectionAbortedException() : + this("The connection was aborted") + { + + } + + public ConnectionAbortedException(string message) : base(message) + { + } + + public ConnectionAbortedException(string message, Exception inner) : base(message, inner) + { + } + } +} diff --git a/src/Servers/Connections.Abstractions/src/Exceptions/ConnectionResetException.cs b/src/Servers/Connections.Abstractions/src/Exceptions/ConnectionResetException.cs new file mode 100644 index 0000000000..78765bc25a --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Exceptions/ConnectionResetException.cs @@ -0,0 +1,19 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; + +namespace Microsoft.AspNetCore.Connections +{ + public class ConnectionResetException : IOException + { + public ConnectionResetException(string message) : base(message) + { + } + + public ConnectionResetException(string message, Exception inner) : base(message, inner) + { + } + } +} diff --git a/src/Servers/Connections.Abstractions/src/Features/IConnectionHeartbeatFeature.cs b/src/Servers/Connections.Abstractions/src/Features/IConnectionHeartbeatFeature.cs new file mode 100644 index 0000000000..cea40d8bdc --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Features/IConnectionHeartbeatFeature.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Connections.Features +{ + public interface IConnectionHeartbeatFeature + { + void OnHeartbeat(Action action, object state); + } +} \ No newline at end of file diff --git a/src/Servers/Connections.Abstractions/src/Features/IConnectionIdFeature.cs b/src/Servers/Connections.Abstractions/src/Features/IConnectionIdFeature.cs new file mode 100644 index 0000000000..2fa7ebbadf --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Features/IConnectionIdFeature.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Connections.Features +{ + public interface IConnectionIdFeature + { + string ConnectionId { get; set; } + } +} diff --git a/src/Servers/Connections.Abstractions/src/Features/IConnectionInherentKeepAliveFeature.cs b/src/Servers/Connections.Abstractions/src/Features/IConnectionInherentKeepAliveFeature.cs new file mode 100644 index 0000000000..d5751f8bcf --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Features/IConnectionInherentKeepAliveFeature.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Connections.Features +{ + /// + /// Indicates if the connection transport has an "inherent keep-alive", which means that the transport will automatically + /// inform the client that it is still present. + /// + /// + /// The most common example of this feature is the Long Polling HTTP transport, which must (due to HTTP limitations) terminate + /// each poll within a particular interval and return a signal indicating "the server is still here, but there is no data yet". + /// This feature allows applications to add keep-alive functionality, but limit it only to transports that don't have some kind + /// of inherent keep-alive. + /// + public interface IConnectionInherentKeepAliveFeature + { + bool HasInherentKeepAlive { get; } + } +} \ No newline at end of file diff --git a/src/Servers/Connections.Abstractions/src/Features/IConnectionItemsFeature.cs b/src/Servers/Connections.Abstractions/src/Features/IConnectionItemsFeature.cs new file mode 100644 index 0000000000..a3aef44310 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Features/IConnectionItemsFeature.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Connections.Features +{ + public interface IConnectionItemsFeature + { + IDictionary Items { get; set; } + } +} \ No newline at end of file diff --git a/src/Servers/Connections.Abstractions/src/Features/IConnectionLifetimeFeature.cs b/src/Servers/Connections.Abstractions/src/Features/IConnectionLifetimeFeature.cs new file mode 100644 index 0000000000..8f804de898 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Features/IConnectionLifetimeFeature.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading; + +namespace Microsoft.AspNetCore.Connections.Features +{ + public interface IConnectionLifetimeFeature + { + CancellationToken ConnectionClosed { get; set; } + void Abort(); + } +} diff --git a/src/Servers/Connections.Abstractions/src/Features/IConnectionTransportFeature.cs b/src/Servers/Connections.Abstractions/src/Features/IConnectionTransportFeature.cs new file mode 100644 index 0000000000..7468dbc722 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Features/IConnectionTransportFeature.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.IO.Pipelines; +using System.Threading; + +namespace Microsoft.AspNetCore.Connections.Features +{ + public interface IConnectionTransportFeature + { + IDuplexPipe Transport { get; set; } + } +} diff --git a/src/Servers/Connections.Abstractions/src/Features/IConnectionUserFeature.cs b/src/Servers/Connections.Abstractions/src/Features/IConnectionUserFeature.cs new file mode 100644 index 0000000000..3efb362fc7 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Features/IConnectionUserFeature.cs @@ -0,0 +1,9 @@ +using System.Security.Claims; + +namespace Microsoft.AspNetCore.Connections.Features +{ + public interface IConnectionUserFeature + { + ClaimsPrincipal User { get; set; } + } +} \ No newline at end of file diff --git a/src/Servers/Connections.Abstractions/src/Features/IMemoryPoolFeature.cs b/src/Servers/Connections.Abstractions/src/Features/IMemoryPoolFeature.cs new file mode 100644 index 0000000000..0a7e28533e --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Features/IMemoryPoolFeature.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.IO.Pipelines; +using System.Threading; + +namespace Microsoft.AspNetCore.Connections.Features +{ + public interface IMemoryPoolFeature + { + MemoryPool MemoryPool { get; } + } +} diff --git a/src/Servers/Connections.Abstractions/src/Features/ITransferFormatFeature.cs b/src/Servers/Connections.Abstractions/src/Features/ITransferFormatFeature.cs new file mode 100644 index 0000000000..ea39a92760 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Features/ITransferFormatFeature.cs @@ -0,0 +1,11 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Connections.Features +{ + public interface ITransferFormatFeature + { + TransferFormat SupportedFormats { get; } + TransferFormat ActiveFormat { get; set; } + } +} \ No newline at end of file diff --git a/src/Servers/Connections.Abstractions/src/IConnectionBuilder.cs b/src/Servers/Connections.Abstractions/src/IConnectionBuilder.cs new file mode 100644 index 0000000000..4825748292 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/IConnectionBuilder.cs @@ -0,0 +1,13 @@ +using System; + +namespace Microsoft.AspNetCore.Connections +{ + public interface IConnectionBuilder + { + IServiceProvider ApplicationServices { get; } + + IConnectionBuilder Use(Func middleware); + + ConnectionDelegate Build(); + } +} diff --git a/src/Servers/Connections.Abstractions/src/Microsoft.AspNetCore.Connections.Abstractions.csproj b/src/Servers/Connections.Abstractions/src/Microsoft.AspNetCore.Connections.Abstractions.csproj new file mode 100644 index 0000000000..5546aef7b8 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Microsoft.AspNetCore.Connections.Abstractions.csproj @@ -0,0 +1,17 @@ + + + + Core components of ASP.NET Core networking protocol stack. + netstandard2.0 + true + aspnetcore + CS1591;$(NoWarn) + + + + + + + + + diff --git a/src/Servers/Connections.Abstractions/src/TransferFormat.cs b/src/Servers/Connections.Abstractions/src/TransferFormat.cs new file mode 100644 index 0000000000..03fd936159 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/TransferFormat.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Connections +{ + [Flags] + public enum TransferFormat + { + Binary = 0x01, + Text = 0x02 + } +} \ No newline at end of file diff --git a/src/Servers/Connections.Abstractions/src/baseline.netcore.json b/src/Servers/Connections.Abstractions/src/baseline.netcore.json new file mode 100644 index 0000000000..7a73a41bfd --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/baseline.netcore.json @@ -0,0 +1,2 @@ +{ +} \ No newline at end of file diff --git a/src/Servers/Directory.Build.props b/src/Servers/Directory.Build.props new file mode 100644 index 0000000000..6b35802689 --- /dev/null +++ b/src/Servers/Directory.Build.props @@ -0,0 +1,9 @@ + + + + + $(RepositoryRoot)obj\$(MSBuildProjectName)\ + $(RepositoryRoot)bin\$(MSBuildProjectName)\ + + + diff --git a/src/Servers/Kestrel/Core/src/Adapter/Internal/AdaptedPipeline.cs b/src/Servers/Kestrel/Core/src/Adapter/Internal/AdaptedPipeline.cs new file mode 100644 index 0000000000..2cbb76e5b1 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Adapter/Internal/AdaptedPipeline.cs @@ -0,0 +1,170 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.IO.Pipelines; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal +{ + public class AdaptedPipeline : IDuplexPipe + { + private static readonly int MinAllocBufferSize = KestrelMemoryPool.MinimumSegmentSize / 2; + + private readonly IDuplexPipe _transport; + + public AdaptedPipeline(IDuplexPipe transport, + Pipe inputPipe, + Pipe outputPipe, + IKestrelTrace log) + { + _transport = transport; + Input = inputPipe; + Output = outputPipe; + Log = log; + } + + public Pipe Input { get; } + + public Pipe Output { get; } + + public IKestrelTrace Log { get; } + + PipeReader IDuplexPipe.Input => Input.Reader; + + PipeWriter IDuplexPipe.Output => Output.Writer; + + public async Task RunAsync(Stream stream) + { + var inputTask = ReadInputAsync(stream); + var outputTask = WriteOutputAsync(stream); + + await inputTask; + await outputTask; + } + + private async Task WriteOutputAsync(Stream stream) + { + try + { + if (stream == null) + { + return; + } + + while (true) + { + var result = await Output.Reader.ReadAsync(); + var buffer = result.Buffer; + + try + { + if (buffer.IsEmpty) + { + if (result.IsCompleted) + { + break; + } + await stream.FlushAsync(); + } + else if (buffer.IsSingleSegment) + { +#if NETCOREAPP2_1 + await stream.WriteAsync(buffer.First); +#else + var array = buffer.First.GetArray(); + await stream.WriteAsync(array.Array, array.Offset, array.Count); +#endif + } + else + { + foreach (var memory in buffer) + { +#if NETCOREAPP2_1 + await stream.WriteAsync(memory); +#else + var array = memory.GetArray(); + await stream.WriteAsync(array.Array, array.Offset, array.Count); +#endif + } + } + } + finally + { + Output.Reader.AdvanceTo(buffer.End); + } + } + } + catch (Exception ex) + { + Log.LogError(0, ex, $"{nameof(AdaptedPipeline)}.{nameof(WriteOutputAsync)}"); + } + finally + { + Output.Reader.Complete(); + _transport.Output.Complete(); + } + } + + private async Task ReadInputAsync(Stream stream) + { + Exception error = null; + + try + { + if (stream == null) + { + // REVIEW: Do we need an exception here? + return; + } + + while (true) + { + + var outputBuffer = Input.Writer.GetMemory(MinAllocBufferSize); +#if NETCOREAPP2_1 + var bytesRead = await stream.ReadAsync(outputBuffer); +#else + var array = outputBuffer.GetArray(); + var bytesRead = await stream.ReadAsync(array.Array, array.Offset, array.Count); +#endif + Input.Writer.Advance(bytesRead); + + if (bytesRead == 0) + { + // FIN + break; + } + + var result = await Input.Writer.FlushAsync(); + + if (result.IsCompleted) + { + break; + } + } + } + catch (Exception ex) + { + // Don't rethrow the exception. It should be handled by the Pipeline consumer. + error = ex; + } + finally + { + Input.Writer.Complete(error); + // The application could have ended the input pipe so complete + // the transport pipe as well + _transport.Input.Complete(); + } + } + + public void Dispose() + { + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Adapter/Internal/ConnectionAdapterContext.cs b/src/Servers/Kestrel/Core/src/Adapter/Internal/ConnectionAdapterContext.cs new file mode 100644 index 0000000000..3896e1cf85 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Adapter/Internal/ConnectionAdapterContext.cs @@ -0,0 +1,26 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal +{ + // Even though this only includes the non-adapted ConnectionStream currently, this is a context in case + // we want to add more connection metadata later. + public class ConnectionAdapterContext + { + internal ConnectionAdapterContext(ConnectionContext connectionContext, Stream connectionStream) + { + ConnectionContext = connectionContext; + ConnectionStream = connectionStream; + } + + internal ConnectionContext ConnectionContext { get; } + + public IFeatureCollection Features => ConnectionContext.Features; + + public Stream ConnectionStream { get; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Adapter/Internal/IAdaptedConnection.cs b/src/Servers/Kestrel/Core/src/Adapter/Internal/IAdaptedConnection.cs new file mode 100644 index 0000000000..5960490e2b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Adapter/Internal/IAdaptedConnection.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal +{ + public interface IAdaptedConnection : IDisposable + { + Stream ConnectionStream { get; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Adapter/Internal/IConnectionAdapter.cs b/src/Servers/Kestrel/Core/src/Adapter/Internal/IConnectionAdapter.cs new file mode 100644 index 0000000000..e0249d5545 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Adapter/Internal/IConnectionAdapter.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal +{ + public interface IConnectionAdapter + { + bool IsHttps { get; } + Task OnConnectionAsync(ConnectionAdapterContext context); + } +} diff --git a/src/Servers/Kestrel/Core/src/Adapter/Internal/LoggingConnectionAdapter.cs b/src/Servers/Kestrel/Core/src/Adapter/Internal/LoggingConnectionAdapter.cs new file mode 100644 index 0000000000..1afd32d1d6 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Adapter/Internal/LoggingConnectionAdapter.cs @@ -0,0 +1,47 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal +{ + public class LoggingConnectionAdapter : IConnectionAdapter + { + private readonly ILogger _logger; + + public LoggingConnectionAdapter(ILogger logger) + { + if (logger == null) + { + throw new ArgumentNullException(nameof(logger)); + } + + _logger = logger; + } + + public bool IsHttps => false; + + public Task OnConnectionAsync(ConnectionAdapterContext context) + { + return Task.FromResult( + new LoggingAdaptedConnection(context.ConnectionStream, _logger)); + } + + private class LoggingAdaptedConnection : IAdaptedConnection + { + public LoggingAdaptedConnection(Stream rawStream, ILogger logger) + { + ConnectionStream = new LoggingStream(rawStream, logger); + } + + public Stream ConnectionStream { get; } + + public void Dispose() + { + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Adapter/Internal/LoggingStream.cs b/src/Servers/Kestrel/Core/src/Adapter/Internal/LoggingStream.cs new file mode 100644 index 0000000000..0b5fdf3975 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Adapter/Internal/LoggingStream.cs @@ -0,0 +1,246 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal +{ + internal class LoggingStream : Stream + { + private readonly Stream _inner; + private readonly ILogger _logger; + + public LoggingStream(Stream inner, ILogger logger) + { + _inner = inner; + _logger = logger; + } + + public override bool CanRead + { + get + { + return _inner.CanRead; + } + } + + public override bool CanSeek + { + get + { + return _inner.CanSeek; + } + } + + public override bool CanWrite + { + get + { + return _inner.CanWrite; + } + } + + public override long Length + { + get + { + return _inner.Length; + } + } + + public override long Position + { + get + { + return _inner.Position; + } + + set + { + _inner.Position = value; + } + } + + public override void Flush() + { + _inner.Flush(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _inner.FlushAsync(cancellationToken); + } + + public override int Read(byte[] buffer, int offset, int count) + { + int read = _inner.Read(buffer, offset, count); + Log("Read", new ReadOnlySpan(buffer, offset, read)); + return read; + } + +#if NETCOREAPP2_1 + public override int Read(Span destination) + { + int read = _inner.Read(destination); + Log("Read", destination.Slice(0, read)); + return read; + } +#endif + + public async override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + int read = await _inner.ReadAsync(buffer, offset, count, cancellationToken); + Log("ReadAsync", new ReadOnlySpan(buffer, offset, read)); + return read; + } + +#if NETCOREAPP2_1 + public override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + { + int read = await _inner.ReadAsync(destination, cancellationToken); + Log("ReadAsync", destination.Span.Slice(0, read)); + return read; + } +#endif + + public override long Seek(long offset, SeekOrigin origin) + { + return _inner.Seek(offset, origin); + } + + public override void SetLength(long value) + { + _inner.SetLength(value); + } + + public override void Write(byte[] buffer, int offset, int count) + { + Log("Write", new ReadOnlySpan(buffer, offset, count)); + _inner.Write(buffer, offset, count); + } + +#if NETCOREAPP2_1 + public override void Write(ReadOnlySpan source) + { + Log("Write", source); + _inner.Write(source); + } +#endif + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + Log("WriteAsync", new ReadOnlySpan(buffer, offset, count)); + return _inner.WriteAsync(buffer, offset, count, cancellationToken); + } + +#if NETCOREAPP2_1 + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + Log("WriteAsync", source.Span); + return _inner.WriteAsync(source, cancellationToken); + } +#endif + + private void Log(string method, ReadOnlySpan buffer) + { + var builder = new StringBuilder($"{method}[{buffer.Length}] "); + + // Write the hex + for (int i = 0; i < buffer.Length; i++) + { + builder.Append(buffer[i].ToString("X2")); + builder.Append(" "); + } + builder.AppendLine(); + // Write the bytes as if they were ASCII + for (int i = 0; i < buffer.Length; i++) + { + builder.Append((char)buffer[i]); + } + + _logger.LogDebug(builder.ToString()); + } + + // The below APM methods call the underlying Read/WriteAsync methods which will still be logged. + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + var task = ReadAsync(buffer, offset, count, default(CancellationToken), state); + if (callback != null) + { + task.ContinueWith(t => callback.Invoke(t)); + } + return task; + } + + public override int EndRead(IAsyncResult asyncResult) + { + return ((Task)asyncResult).GetAwaiter().GetResult(); + } + + private Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state) + { + var tcs = new TaskCompletionSource(state); + var task = ReadAsync(buffer, offset, count, cancellationToken); + task.ContinueWith((task2, state2) => + { + var tcs2 = (TaskCompletionSource)state2; + if (task2.IsCanceled) + { + tcs2.SetCanceled(); + } + else if (task2.IsFaulted) + { + tcs2.SetException(task2.Exception); + } + else + { + tcs2.SetResult(task2.Result); + } + }, tcs, cancellationToken); + return tcs.Task; + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + var task = WriteAsync(buffer, offset, count, default(CancellationToken), state); + if (callback != null) + { + task.ContinueWith(t => callback.Invoke(t)); + } + return task; + } + + public override void EndWrite(IAsyncResult asyncResult) + { + ((Task)asyncResult).GetAwaiter().GetResult(); + } + + private Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state) + { + var tcs = new TaskCompletionSource(state); + var task = WriteAsync(buffer, offset, count, cancellationToken); + task.ContinueWith((task2, state2) => + { + var tcs2 = (TaskCompletionSource)state2; + if (task2.IsCanceled) + { + tcs2.SetCanceled(); + } + else if (task2.IsFaulted) + { + tcs2.SetException(task2.Exception); + } + else + { + tcs2.SetResult(null); + } + }, tcs, cancellationToken); + return tcs.Task; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Adapter/Internal/RawStream.cs b/src/Servers/Kestrel/Core/src/Adapter/Internal/RawStream.cs new file mode 100644 index 0000000000..084eed2418 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Adapter/Internal/RawStream.cs @@ -0,0 +1,217 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using System.Buffers; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal +{ + public class RawStream : Stream + { + private readonly PipeReader _input; + private readonly PipeWriter _output; + + public RawStream(PipeReader input, PipeWriter output) + { + _input = input; + _output = output; + } + + public override bool CanRead => true; + + public override bool CanSeek => false; + + public override bool CanWrite => true; + + public override long Length + { + get + { + throw new NotSupportedException(); + } + } + + public override long Position + { + get + { + throw new NotSupportedException(); + } + set + { + throw new NotSupportedException(); + } + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + // ValueTask uses .GetAwaiter().GetResult() if necessary + // https://github.com/dotnet/corefx/blob/f9da3b4af08214764a51b2331f3595ffaf162abe/src/System.Threading.Tasks.Extensions/src/System/Threading/Tasks/ValueTask.cs#L156 + return ReadAsyncInternal(new Memory(buffer, offset, count)).Result; + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return ReadAsyncInternal(new Memory(buffer, offset, count)).AsTask(); + } + +#if NETCOREAPP2_1 + public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + { + return ReadAsyncInternal(destination); + } +#endif + + public override void Write(byte[] buffer, int offset, int count) + { + WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (buffer != null) + { + _output.Write(new ReadOnlySpan(buffer, offset, count)); + } + + await _output.FlushAsync(cancellationToken); + } + +#if NETCOREAPP2_1 + public override async ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + _output.Write(source.Span); + await _output.FlushAsync(cancellationToken); + } +#endif + + public override void Flush() + { + FlushAsync(CancellationToken.None).GetAwaiter().GetResult(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return WriteAsync(null, 0, 0, cancellationToken); + } + + private async ValueTask ReadAsyncInternal(Memory destination) + { + while (true) + { + var result = await _input.ReadAsync(); + var readableBuffer = result.Buffer; + try + { + if (!readableBuffer.IsEmpty) + { + // buffer.Count is int + var count = (int) Math.Min(readableBuffer.Length, destination.Length); + readableBuffer = readableBuffer.Slice(0, count); + readableBuffer.CopyTo(destination.Span); + return count; + } + + if (result.IsCompleted) + { + return 0; + } + } + finally + { + _input.AdvanceTo(readableBuffer.End, readableBuffer.End); + } + } + } + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + var task = ReadAsync(buffer, offset, count, default(CancellationToken), state); + if (callback != null) + { + task.ContinueWith(t => callback.Invoke(t)); + } + return task; + } + + public override int EndRead(IAsyncResult asyncResult) + { + return ((Task)asyncResult).GetAwaiter().GetResult(); + } + + private Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state) + { + var tcs = new TaskCompletionSource(state); + var task = ReadAsync(buffer, offset, count, cancellationToken); + task.ContinueWith((task2, state2) => + { + var tcs2 = (TaskCompletionSource)state2; + if (task2.IsCanceled) + { + tcs2.SetCanceled(); + } + else if (task2.IsFaulted) + { + tcs2.SetException(task2.Exception); + } + else + { + tcs2.SetResult(task2.Result); + } + }, tcs, cancellationToken); + return tcs.Task; + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + var task = WriteAsync(buffer, offset, count, default(CancellationToken), state); + if (callback != null) + { + task.ContinueWith(t => callback.Invoke(t)); + } + return task; + } + + public override void EndWrite(IAsyncResult asyncResult) + { + ((Task)asyncResult).GetAwaiter().GetResult(); + } + + private Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state) + { + var tcs = new TaskCompletionSource(state); + var task = WriteAsync(buffer, offset, count, cancellationToken); + task.ContinueWith((task2, state2) => + { + var tcs2 = (TaskCompletionSource)state2; + if (task2.IsCanceled) + { + tcs2.SetCanceled(); + } + else if (task2.IsFaulted) + { + tcs2.SetException(task2.Exception); + } + else + { + tcs2.SetResult(null); + } + }, tcs, cancellationToken); + return tcs.Task; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Adapter/ListenOptionsConnectionLoggingExtensions.cs b/src/Servers/Kestrel/Core/src/Adapter/ListenOptionsConnectionLoggingExtensions.cs new file mode 100644 index 0000000000..11306602c6 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Adapter/ListenOptionsConnectionLoggingExtensions.cs @@ -0,0 +1,38 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Hosting +{ + public static class ListenOptionsConnectionLoggingExtensions + { + /// + /// Emits verbose logs for bytes read from and written to the connection. + /// + /// + /// The . + /// + public static ListenOptions UseConnectionLogging(this ListenOptions listenOptions) + { + return listenOptions.UseConnectionLogging(nameof(LoggingConnectionAdapter)); + } + + /// + /// Emits verbose logs for bytes read from and written to the connection. + /// + /// + /// The . + /// + public static ListenOptions UseConnectionLogging(this ListenOptions listenOptions, string loggerName) + { + var loggerFactory = listenOptions.KestrelServerOptions.ApplicationServices.GetRequiredService(); + var logger = loggerFactory.CreateLogger(loggerName ?? nameof(LoggingConnectionAdapter)); + listenOptions.ConnectionAdapters.Add(new LoggingConnectionAdapter(logger)); + return listenOptions; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/AnyIPListenOptions.cs b/src/Servers/Kestrel/Core/src/AnyIPListenOptions.cs new file mode 100644 index 0000000000..2639337dd7 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/AnyIPListenOptions.cs @@ -0,0 +1,37 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core +{ + internal class AnyIPListenOptions : ListenOptions + { + internal AnyIPListenOptions(int port) + : base(new IPEndPoint(IPAddress.IPv6Any, port)) + { + } + + internal override async Task BindAsync(AddressBindContext context) + { + // when address is 'http://hostname:port', 'http://*:port', or 'http://+:port' + try + { + await base.BindAsync(context).ConfigureAwait(false); + } + catch (Exception ex) when (!(ex is IOException)) + { + context.Logger.LogDebug(CoreStrings.FormatFallbackToIPv4Any(IPEndPoint.Port)); + + // for machines that do not support IPv6 + IPEndPoint = new IPEndPoint(IPAddress.Any, IPEndPoint.Port); + await base.BindAsync(context).ConfigureAwait(false); + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/BadHttpRequestException.cs b/src/Servers/Kestrel/Core/src/BadHttpRequestException.cs new file mode 100644 index 0000000000..033c8d9e1b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/BadHttpRequestException.cs @@ -0,0 +1,176 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Diagnostics; +using System.IO; +using System.Runtime.CompilerServices; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core +{ + public sealed class BadHttpRequestException : IOException + { + private BadHttpRequestException(string message, int statusCode, RequestRejectionReason reason) + : this(message, statusCode, reason, null) + { } + + private BadHttpRequestException(string message, int statusCode, RequestRejectionReason reason, HttpMethod? requiredMethod) + : base(message) + { + StatusCode = statusCode; + Reason = reason; + + if (requiredMethod.HasValue) + { + AllowedHeader = HttpUtilities.MethodToString(requiredMethod.Value); + } + } + + internal int StatusCode { get; } + + internal StringValues AllowedHeader { get; } + + internal RequestRejectionReason Reason { get; } + + [StackTraceHidden] + internal static void Throw(RequestRejectionReason reason) + { + throw GetException(reason); + } + + [StackTraceHidden] + public static void Throw(RequestRejectionReason reason, HttpMethod method) + => throw GetException(reason, method.ToString().ToUpperInvariant()); + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static BadHttpRequestException GetException(RequestRejectionReason reason) + { + BadHttpRequestException ex; + switch (reason) + { + case RequestRejectionReason.InvalidRequestHeadersNoCRLF: + ex = new BadHttpRequestException(CoreStrings.BadRequest_InvalidRequestHeadersNoCRLF, StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.InvalidRequestLine: + ex = new BadHttpRequestException(CoreStrings.BadRequest_InvalidRequestLine, StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.MalformedRequestInvalidHeaders: + ex = new BadHttpRequestException(CoreStrings.BadRequest_MalformedRequestInvalidHeaders, StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.MultipleContentLengths: + ex = new BadHttpRequestException(CoreStrings.BadRequest_MultipleContentLengths, StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.UnexpectedEndOfRequestContent: + ex = new BadHttpRequestException(CoreStrings.BadRequest_UnexpectedEndOfRequestContent, StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.BadChunkSuffix: + ex = new BadHttpRequestException(CoreStrings.BadRequest_BadChunkSuffix, StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.BadChunkSizeData: + ex = new BadHttpRequestException(CoreStrings.BadRequest_BadChunkSizeData, StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.ChunkedRequestIncomplete: + ex = new BadHttpRequestException(CoreStrings.BadRequest_ChunkedRequestIncomplete, StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.InvalidCharactersInHeaderName: + ex = new BadHttpRequestException(CoreStrings.BadRequest_InvalidCharactersInHeaderName, StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.RequestLineTooLong: + ex = new BadHttpRequestException(CoreStrings.BadRequest_RequestLineTooLong, StatusCodes.Status414UriTooLong, reason); + break; + case RequestRejectionReason.HeadersExceedMaxTotalSize: + ex = new BadHttpRequestException(CoreStrings.BadRequest_HeadersExceedMaxTotalSize, StatusCodes.Status431RequestHeaderFieldsTooLarge, reason); + break; + case RequestRejectionReason.TooManyHeaders: + ex = new BadHttpRequestException(CoreStrings.BadRequest_TooManyHeaders, StatusCodes.Status431RequestHeaderFieldsTooLarge, reason); + break; + case RequestRejectionReason.RequestBodyTooLarge: + ex = new BadHttpRequestException(CoreStrings.BadRequest_RequestBodyTooLarge, StatusCodes.Status413PayloadTooLarge, reason); + break; + case RequestRejectionReason.RequestHeadersTimeout: + ex = new BadHttpRequestException(CoreStrings.BadRequest_RequestHeadersTimeout, StatusCodes.Status408RequestTimeout, reason); + break; + case RequestRejectionReason.RequestBodyTimeout: + ex = new BadHttpRequestException(CoreStrings.BadRequest_RequestBodyTimeout, StatusCodes.Status408RequestTimeout, reason); + break; + case RequestRejectionReason.OptionsMethodRequired: + ex = new BadHttpRequestException(CoreStrings.BadRequest_MethodNotAllowed, StatusCodes.Status405MethodNotAllowed, reason, HttpMethod.Options); + break; + case RequestRejectionReason.ConnectMethodRequired: + ex = new BadHttpRequestException(CoreStrings.BadRequest_MethodNotAllowed, StatusCodes.Status405MethodNotAllowed, reason, HttpMethod.Connect); + break; + case RequestRejectionReason.MissingHostHeader: + ex = new BadHttpRequestException(CoreStrings.BadRequest_MissingHostHeader, StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.MultipleHostHeaders: + ex = new BadHttpRequestException(CoreStrings.BadRequest_MultipleHostHeaders, StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.InvalidHostHeader: + ex = new BadHttpRequestException(CoreStrings.BadRequest_InvalidHostHeader, StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.UpgradeRequestCannotHavePayload: + ex = new BadHttpRequestException(CoreStrings.BadRequest_UpgradeRequestCannotHavePayload, StatusCodes.Status400BadRequest, reason); + break; + default: + ex = new BadHttpRequestException(CoreStrings.BadRequest, StatusCodes.Status400BadRequest, reason); + break; + } + return ex; + } + + [StackTraceHidden] + internal static void Throw(RequestRejectionReason reason, string detail) + { + throw GetException(reason, detail); + } + + [StackTraceHidden] + internal static void Throw(RequestRejectionReason reason, in StringValues detail) + { + throw GetException(reason, detail.ToString()); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + internal static BadHttpRequestException GetException(RequestRejectionReason reason, string detail) + { + BadHttpRequestException ex; + switch (reason) + { + case RequestRejectionReason.InvalidRequestLine: + ex = new BadHttpRequestException(CoreStrings.FormatBadRequest_InvalidRequestLine_Detail(detail), StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.InvalidRequestTarget: + ex = new BadHttpRequestException(CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(detail), StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.InvalidRequestHeader: + ex = new BadHttpRequestException(CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(detail), StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.InvalidContentLength: + ex = new BadHttpRequestException(CoreStrings.FormatBadRequest_InvalidContentLength_Detail(detail), StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.UnrecognizedHTTPVersion: + ex = new BadHttpRequestException(CoreStrings.FormatBadRequest_UnrecognizedHTTPVersion(detail), StatusCodes.Status505HttpVersionNotsupported, reason); + break; + case RequestRejectionReason.FinalTransferCodingNotChunked: + ex = new BadHttpRequestException(CoreStrings.FormatBadRequest_FinalTransferCodingNotChunked(detail), StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.LengthRequired: + ex = new BadHttpRequestException(CoreStrings.FormatBadRequest_LengthRequired(detail), StatusCodes.Status411LengthRequired, reason); + break; + case RequestRejectionReason.LengthRequiredHttp10: + ex = new BadHttpRequestException(CoreStrings.FormatBadRequest_LengthRequiredHttp10(detail), StatusCodes.Status400BadRequest, reason); + break; + case RequestRejectionReason.InvalidHostHeader: + ex = new BadHttpRequestException(CoreStrings.FormatBadRequest_InvalidHostHeader_Detail(detail), StatusCodes.Status400BadRequest, reason); + break; + default: + ex = new BadHttpRequestException(CoreStrings.BadRequest, StatusCodes.Status400BadRequest, reason); + break; + } + return ex; + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/src/ClientCertificateMode.cs b/src/Servers/Kestrel/Core/src/ClientCertificateMode.cs new file mode 100644 index 0000000000..caff5e041a --- /dev/null +++ b/src/Servers/Kestrel/Core/src/ClientCertificateMode.cs @@ -0,0 +1,26 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Https +{ + /// + /// Describes the client certificate requirements for a HTTPS connection. + /// + public enum ClientCertificateMode + { + /// + /// A client certificate is not required and will not be requested from clients. + /// + NoCertificate, + + /// + /// A client certificate will be requested; however, authentication will not fail if a certificate is not provided by the client. + /// + AllowCertificate, + + /// + /// A client certificate will be requested, and the client must provide a valid certificate for authentication to succeed. + /// + RequireCertificate + } +} diff --git a/src/Servers/Kestrel/Core/src/CoreStrings.resx b/src/Servers/Kestrel/Core/src/CoreStrings.resx new file mode 100644 index 0000000000..ee80a16312 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/CoreStrings.resx @@ -0,0 +1,521 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Bad request. + + + Bad chunk size data. + + + Bad chunk suffix. + + + Chunked request incomplete. + + + The message body length cannot be determined because the final transfer coding was set to '{detail}' instead of 'chunked'. + + + Request headers too long. + + + Invalid characters in header name. + + + Invalid content length: {detail} + + + Invalid Host header. + + + Invalid Host header: '{detail}' + + + Invalid request headers: missing final CRLF in header fields. + + + Invalid request header: '{detail}' + + + Invalid request line. + + + Invalid request line: '{detail}' + + + Invalid request target: '{detail}' + + + {detail} request contains no Content-Length or Transfer-Encoding header. + + + {detail} request contains no Content-Length header. + + + Malformed request: invalid headers. + + + Method not allowed. + + + Request is missing Host header. + + + Multiple Content-Length headers. + + + Multiple Host headers. + + + Request line too long. + + + Reading the request headers timed out. + + + Request contains too many headers. + + + Unexpected end of request content. + + + Unrecognized HTTP version: '{detail}' + + + Requests with 'Connection: Upgrade' cannot have content in the request body. + + + Failed to bind to http://[::]:{port} (IPv6Any). Attempting to bind to http://0.0.0.0:{port} instead. + + + Cannot write to response body after connection has been upgraded. + + + Kestrel does not support big-endian architectures. + + + Maximum request buffer size ({requestBufferSize}) must be greater than or equal to maximum request header size ({requestHeaderSize}). + + + Maximum request buffer size ({requestBufferSize}) must be greater than or equal to maximum request line size ({requestLineSize}). + + + Server has already started. + + + Unknown transport mode: '{mode}'. + + + Invalid non-ASCII or control character in header: {character} + + + Invalid Content-Length: "{value}". Value must be a positive integral number. + + + Value must be null or a non-negative number. + + + Value must be a non-negative number. + + + Value must be a positive number. + + + Value must be null or a positive number. + + + Unix socket path must be absolute. + + + Failed to bind to address {address}. + + + No listening endpoints were configured. Binding to {address} by default. + + + HTTPS endpoints can only be configured using {methodName}. + + + A path base can only be configured using {methodName}. + + + Dynamic port binding is not supported when binding to localhost. You must either bind to 127.0.0.1:0 or [::1]:0, or both. + + + Failed to bind to address {endpoint}: address already in use. + + + Invalid URL: '{url}'. + + + Unable to bind to {address} on the {interfaceName} interface: '{error}'. + + + Overriding address(es) '{addresses}'. Binding to endpoints defined in {methodName} instead. + + + Overriding endpoints defined in UseKestrel() because {settingName} is set to true. Binding to address(es) '{addresses}' instead. + + + Unrecognized scheme in server address '{address}'. Only 'http://' is supported. + + + Headers are read-only, response has already started. + + + An item with the same key has already been added. + + + Setting the header {name} is not allowed on responses with status code {statusCode}. + + + {name} cannot be set because the response has already started. + + + Request processing didn't complete within the shutdown timeout. + + + Response Content-Length mismatch: too few bytes written ({written} of {expected}). + + + Response Content-Length mismatch: too many bytes written ({written} of {expected}). + + + The response has been aborted due to an unhandled application exception. + + + Writing to the response body is invalid for responses with status code {statusCode}. + + + Connection shutdown abnormally. + + + Connection processing ended abnormally. + + + Cannot upgrade a non-upgradable request. Check IHttpUpgradeFeature.IsUpgradableRequest to determine if a request can be upgraded. + + + Request cannot be upgraded because the server has already opened the maximum number of upgraded connections. + + + IHttpUpgradeFeature.UpgradeAsync was already called and can only be called once per connection. + + + Request body too large. + + + The maximum request body size cannot be modified after the app has already started reading from the request body. + + + The maximum request body size cannot be modified after the request has been upgraded. + + + Value must be a positive TimeSpan. + + + Value must be a non-negative TimeSpan. + + + The request body rate enforcement grace period must be greater than {heartbeatInterval} second. + + + Synchronous operations are disallowed. Call ReadAsync or set AllowSynchronousIO to true instead. + + + Synchronous operations are disallowed. Call WriteAsync or set AllowSynchronousIO to true instead. + + + Value must be a positive number. To disable a minimum data rate, use null where a MinDataRate instance is expected. + + + Concurrent timeouts are not supported. + + + Timespan must be positive and finite. + + + An endpoint must be configured to serve at least one protocol. + + + Using both HTTP/1.x and HTTP/2 on the same endpoint requires the use of TLS. + + + HTTP/2 over TLS was not negotiated on an HTTP/2-only endpoint. + + + A dynamic table size of {size} octets is greater than the configured maximum size of {maxSize} octets. + + + Index {index} is outside the bounds of the header field table. + + + Input data could not be fully decoded. + + + Input data contains the EOS symbol. + + + The destination buffer is not large enough to store the decoded data. + + + Huffman decoding error. + + + Decoded string length of {length} octets is greater than the configured maximum length of {maxStringLength} octets. + + + The header block was incomplete and could not be fully decoded. + + + The client sent a {frameType} frame with even stream ID {streamId}. + + + The client sent a A PUSH_PROMISE frame. + + + The client sent a {frameType} frame to stream ID {streamId} before signaling of the header block for stream ID {headersStreamId}. + + + The client sent a {frameType} frame with stream ID 0. + + + The client sent a {frameType} frame with stream ID different than 0. + + + The client sent a {frameType} frame with padding longer than or with the same length as the sent data. + + + The client sent a {frameType} frame to closed stream ID {streamId}. + + + The client sent a {frameType} frame to stream ID {streamId} which is in the "half-closed (remote) state". + + + The client sent a {frameType} frame with dependency information that would cause stream ID {streamId} to depend on itself. + + + The client sent a {frameType} frame with length different than {expectedLength}. + + + The client sent a SETTINGS frame with a length that is not a multiple of 6. + + + The client sent a SETTINGS frame with ACK set and length different than 0. + + + The client sent a SETTINGS frame with a value for parameter {parameter} that is out of range. + + + The client sent a WINDOW_UPDATE frame with a window size increment of 0. + + + The client sent a CONTINUATION frame not preceded by a HEADERS frame. + + + The client sent a {frameType} frame to idle stream ID {streamId}. + + + The client sent trailers containing one or more pseudo-header fields. + + + The client sent a header with uppercase characters in its name. + + + The client sent a trailer with uppercase characters in its name. + + + The client sent a HEADERS frame containing trailers without setting the END_STREAM flag. + + + Request headers missing one or more mandatory pseudo-header fields. + + + Pseudo-header field found in request headers after regular header fields. + + + Request headers contain unknown pseudo-header field. + + + Request headers contain response-specific pseudo-header field. + + + Request headers contain duplicate pseudo-header field. + + + Request headers contain connection-specific header field. + + + Unable to configure default https bindings because no IDefaultHttpsProvider service was provided. + + + Failed to authenticate HTTPS connection. + + + Authentication of the HTTPS connection timed out. + + + Certificate {thumbprint} cannot be used as an SSL server certificate. It has an Extended Key Usage extension but the usages do not include Server Authentication (OID 1.3.6.1.5.5.7.3.1). + + + Value must be a positive TimeSpan. + + + The server certificate parameter is required. + + + No listening endpoints were configured. Binding to {address0} and {address1} by default. + + + The requested certificate {subject} could not be found in {storeLocation}/{storeName} with AllowInvalid setting: {allowInvalid}. + + + The endpoint {endpointName} is missing the required 'Url' parameter. + + + Unable to configure HTTPS endpoint. No server certificate was specified, and the default developer certificate could not be found. +To generate a developer certificate run 'dotnet dev-certs https'. To trust the certificate (Windows and macOS only) run 'dotnet dev-certs https --trust'. +For more information on configuring HTTPS see https://go.microsoft.com/fwlink/?linkid=848054. + + + The endpoint {endpointName} specified multiple certificate sources. + + + HTTP/2 support is experimental, see https://go.microsoft.com/fwlink/?linkid=866785 to enable it. + + + Cannot write to the response body, the response has completed. + + + Reading the request body timed out due to data arriving too slowly. See MinRequestBodyDataRate. + + + The connection was aborted by the application. + + + The connection was aborted because the server is shutting down and request processing didn't complete within the time specified by HostOptions.ShutdownTimeout. + + + The connection was timed out by the server because the response was not read by the client at the specified minimum data rate. + + + The connection was timed out by the server. + + \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/src/EndpointConfiguration.cs b/src/Servers/Kestrel/Core/src/EndpointConfiguration.cs new file mode 100644 index 0000000000..94848b14bd --- /dev/null +++ b/src/Servers/Kestrel/Core/src/EndpointConfiguration.cs @@ -0,0 +1,26 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.Extensions.Configuration; + +namespace Microsoft.AspNetCore.Server.Kestrel +{ + public class EndpointConfiguration + { + internal EndpointConfiguration(bool isHttps, ListenOptions listenOptions, HttpsConnectionAdapterOptions httpsOptions, IConfigurationSection configSection) + { + IsHttps = isHttps; + ListenOptions = listenOptions ?? throw new ArgumentNullException(nameof(listenOptions)); + HttpsOptions = httpsOptions ?? throw new ArgumentNullException(nameof(httpsOptions)); + ConfigSection = configSection ?? throw new ArgumentNullException(nameof(configSection)); + } + + public bool IsHttps { get; } + public ListenOptions ListenOptions { get; } + public HttpsConnectionAdapterOptions HttpsOptions { get; } + public IConfigurationSection ConfigSection { get; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Features/IConnectionTimeoutFeature.cs b/src/Servers/Kestrel/Core/src/Features/IConnectionTimeoutFeature.cs new file mode 100644 index 0000000000..e7634c3d88 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Features/IConnectionTimeoutFeature.cs @@ -0,0 +1,31 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Features +{ + /// + /// Feature for efficiently handling connection timeouts. + /// + public interface IConnectionTimeoutFeature + { + /// + /// Close the connection after the specified positive finite + /// unless the timeout is canceled or reset. This will fail if there is an ongoing timeout. + /// + void SetTimeout(TimeSpan timeSpan); + + /// + /// Close the connection after the specified positive finite + /// unless the timeout is canceled or reset. This will cancel any ongoing timeouts. + /// + void ResetTimeout(TimeSpan timeSpan); + + /// + /// Prevent the connection from closing after a timeout specified by + /// or . + /// + void CancelTimeout(); + } +} diff --git a/src/Servers/Kestrel/Core/src/Features/IDecrementConcurrentConnectionCountFeature.cs b/src/Servers/Kestrel/Core/src/Features/IDecrementConcurrentConnectionCountFeature.cs new file mode 100644 index 0000000000..d34b1d1439 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Features/IDecrementConcurrentConnectionCountFeature.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Features +{ + /// + /// A connection feature allowing middleware to stop counting connections towards . + /// This is used by Kestrel internally to stop counting upgraded connections towards this limit. + /// + public interface IDecrementConcurrentConnectionCountFeature + { + /// + /// Idempotent method to stop counting a connection towards . + /// + void ReleaseConnection(); + } +} diff --git a/src/Servers/Kestrel/Core/src/Features/IHttp2StreamIdFeature.cs b/src/Servers/Kestrel/Core/src/Features/IHttp2StreamIdFeature.cs new file mode 100644 index 0000000000..30ad135062 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Features/IHttp2StreamIdFeature.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Features +{ + public interface IHttp2StreamIdFeature + { + int StreamId { get; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Features/IHttpMinRequestBodyDataRateFeature.cs b/src/Servers/Kestrel/Core/src/Features/IHttpMinRequestBodyDataRateFeature.cs new file mode 100644 index 0000000000..f80bd99772 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Features/IHttpMinRequestBodyDataRateFeature.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Features +{ + /// + /// Feature to set the minimum data rate at which the the request body must be sent by the client. + /// + public interface IHttpMinRequestBodyDataRateFeature + { + /// + /// The minimum data rate in bytes/second at which the request body must be sent by the client. + /// Setting this property to null indicates no minimum data rate should be enforced. + /// This limit has no effect on upgraded connections which are always unlimited. + /// + MinDataRate MinDataRate { get; set; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Features/IHttpMinResponseDataRateFeature.cs b/src/Servers/Kestrel/Core/src/Features/IHttpMinResponseDataRateFeature.cs new file mode 100644 index 0000000000..f901a338d9 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Features/IHttpMinResponseDataRateFeature.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Features +{ + /// + /// Feature to set the minimum data rate at which the response must be received by the client. + /// + public interface IHttpMinResponseDataRateFeature + { + /// + /// The minimum data rate in bytes/second at which the response must be received by the client. + /// Setting this property to null indicates no minimum data rate should be enforced. + /// This limit has no effect on upgraded connections which are always unlimited. + /// + MinDataRate MinDataRate { get; set; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Features/ITlsApplicationProtocolFeature.cs b/src/Servers/Kestrel/Core/src/Features/ITlsApplicationProtocolFeature.cs new file mode 100644 index 0000000000..8adca3f0e8 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Features/ITlsApplicationProtocolFeature.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Features +{ + public interface ITlsApplicationProtocolFeature + { + ReadOnlyMemory ApplicationProtocol { get; } + } +} diff --git a/src/Servers/Kestrel/Core/src/HttpProtocols.cs b/src/Servers/Kestrel/Core/src/HttpProtocols.cs new file mode 100644 index 0000000000..09524bf156 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/HttpProtocols.cs @@ -0,0 +1,16 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core +{ + [Flags] + public enum HttpProtocols + { + None = 0x0, + Http1 = 0x1, + Http2 = 0x2, + Http1AndHttp2 = Http1 | Http2, + } +} diff --git a/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs b/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs new file mode 100644 index 0000000000..cf6bd88236 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/HttpsConnectionAdapterOptions.cs @@ -0,0 +1,94 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net.Security; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core; + +namespace Microsoft.AspNetCore.Server.Kestrel.Https +{ + /// + /// Settings for how Kestrel should handle HTTPS connections. + /// + public class HttpsConnectionAdapterOptions + { + private TimeSpan _handshakeTimeout; + + /// + /// Initializes a new instance of . + /// + public HttpsConnectionAdapterOptions() + { + ClientCertificateMode = ClientCertificateMode.NoCertificate; + SslProtocols = SslProtocols.Tls12 | SslProtocols.Tls11; + HandshakeTimeout = TimeSpan.FromSeconds(10); + } + + /// + /// + /// Specifies the server certificate used to authenticate HTTPS connections. This is ignored if ServerCertificateSelector is set. + /// + /// + /// If the server certificate has an Extended Key Usage extension, the usages must include Server Authentication (OID 1.3.6.1.5.5.7.3.1). + /// + /// + public X509Certificate2 ServerCertificate { get; set; } + + /// + /// + /// A callback that will be invoked to dynamically select a server certificate. This is higher priority than ServerCertificate. + /// If SNI is not avialable then the name parameter will be null. + /// + /// + /// If the server certificate has an Extended Key Usage extension, the usages must include Server Authentication (OID 1.3.6.1.5.5.7.3.1). + /// + /// + public Func ServerCertificateSelector { get; set; } + + /// + /// Specifies the client certificate requirements for a HTTPS connection. Defaults to . + /// + public ClientCertificateMode ClientCertificateMode { get; set; } + + /// + /// Specifies a callback for additional client certificate validation that will be invoked during authentication. + /// + public Func ClientCertificateValidation { get; set; } + + /// + /// Specifies allowable SSL protocols. Defaults to and . + /// + public SslProtocols SslProtocols { get; set; } + + /// + /// The protocols enabled on this endpoint. + /// + /// Defaults to HTTP/1.x only. + internal HttpProtocols HttpProtocols { get; set; } + + /// + /// Specifies whether the certificate revocation list is checked during authentication. + /// + public bool CheckCertificateRevocation { get; set; } + + /// + /// Specifies the maximum amount of time allowed for the TLS/SSL handshake. This must be positive and finite. + /// + public TimeSpan HandshakeTimeout + { + get => _handshakeTimeout; + set + { + if (value <= TimeSpan.Zero && value != Timeout.InfiniteTimeSpan) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.PositiveTimeSpanRequired); + } + _handshakeTimeout = value != Timeout.InfiniteTimeSpan ? value : TimeSpan.MaxValue; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs b/src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs new file mode 100644 index 0000000000..f4c1859b7f --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + internal class AddressBindContext + { + public ICollection Addresses { get; set; } + public List ListenOptions { get; set; } + public KestrelServerOptions ServerOptions { get; set; } + public ILogger Logger { get; set; } + + public Func CreateBinding { get; set; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs b/src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs new file mode 100644 index 0000000000..67ff4f1108 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs @@ -0,0 +1,267 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Hosting.Server.Features; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + internal class AddressBinder + { + public static async Task BindAsync(IServerAddressesFeature addresses, + KestrelServerOptions serverOptions, + ILogger logger, + Func createBinding) + { + var listenOptions = serverOptions.ListenOptions; + var strategy = CreateStrategy( + listenOptions.ToArray(), + addresses.Addresses.ToArray(), + addresses.PreferHostingUrls); + + var context = new AddressBindContext + { + Addresses = addresses.Addresses, + ListenOptions = listenOptions, + ServerOptions = serverOptions, + Logger = logger, + CreateBinding = createBinding + }; + + // reset options. The actual used options and addresses will be populated + // by the address binding feature + listenOptions.Clear(); + addresses.Addresses.Clear(); + + await strategy.BindAsync(context).ConfigureAwait(false); + } + + private static IStrategy CreateStrategy(ListenOptions[] listenOptions, string[] addresses, bool preferAddresses) + { + var hasListenOptions = listenOptions.Length > 0; + var hasAddresses = addresses.Length > 0; + + if (preferAddresses && hasAddresses) + { + if (hasListenOptions) + { + return new OverrideWithAddressesStrategy(addresses); + } + + return new AddressesStrategy(addresses); + } + else if (hasListenOptions) + { + if (hasAddresses) + { + return new OverrideWithEndpointsStrategy(listenOptions, addresses); + } + + return new EndpointsStrategy(listenOptions); + } + else if (hasAddresses) + { + // If no endpoints are configured directly using KestrelServerOptions, use those configured via the IServerAddressesFeature. + return new AddressesStrategy(addresses); + } + else + { + // "localhost" for both IPv4 and IPv6 can't be represented as an IPEndPoint. + return new DefaultAddressStrategy(); + } + } + + /// + /// Returns an for the given host an port. + /// If the host parameter isn't "localhost" or an IP address, use IPAddress.Any. + /// + protected internal static bool TryCreateIPEndPoint(ServerAddress address, out IPEndPoint endpoint) + { + if (!IPAddress.TryParse(address.Host, out var ip)) + { + endpoint = null; + return false; + } + + endpoint = new IPEndPoint(ip, address.Port); + return true; + } + + internal static async Task BindEndpointAsync(ListenOptions endpoint, AddressBindContext context) + { + try + { + await context.CreateBinding(endpoint).ConfigureAwait(false); + } + catch (AddressInUseException ex) + { + throw new IOException(CoreStrings.FormatEndpointAlreadyInUse(endpoint), ex); + } + + context.ListenOptions.Add(endpoint); + } + + internal static ListenOptions ParseAddress(string address, out bool https) + { + var parsedAddress = ServerAddress.FromUrl(address); + https = false; + + if (parsedAddress.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase)) + { + https = true; + } + else if (!parsedAddress.Scheme.Equals("http", StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException(CoreStrings.FormatUnsupportedAddressScheme(address)); + } + + if (!string.IsNullOrEmpty(parsedAddress.PathBase)) + { + throw new InvalidOperationException(CoreStrings.FormatConfigurePathBaseFromMethodCall($"{nameof(IApplicationBuilder)}.UsePathBase()")); + } + + ListenOptions options = null; + if (parsedAddress.IsUnixPipe) + { + options = new ListenOptions(parsedAddress.UnixPipePath); + } + else if (string.Equals(parsedAddress.Host, "localhost", StringComparison.OrdinalIgnoreCase)) + { + // "localhost" for both IPv4 and IPv6 can't be represented as an IPEndPoint. + options = new LocalhostListenOptions(parsedAddress.Port); + } + else if (TryCreateIPEndPoint(parsedAddress, out var endpoint)) + { + options = new ListenOptions(endpoint); + } + else + { + // when address is 'http://hostname:port', 'http://*:port', or 'http://+:port' + options = new AnyIPListenOptions(parsedAddress.Port); + } + + return options; + } + + private interface IStrategy + { + Task BindAsync(AddressBindContext context); + } + + private class DefaultAddressStrategy : IStrategy + { + public async Task BindAsync(AddressBindContext context) + { + var httpDefault = ParseAddress(Constants.DefaultServerAddress, out var https); + context.ServerOptions.ApplyEndpointDefaults(httpDefault); + await httpDefault.BindAsync(context).ConfigureAwait(false); + + // Conditional https default, only if a cert is available + var httpsDefault = ParseAddress(Constants.DefaultServerHttpsAddress, out https); + context.ServerOptions.ApplyEndpointDefaults(httpsDefault); + + if (httpsDefault.ConnectionAdapters.Any(f => f.IsHttps) + || httpsDefault.TryUseHttps()) + { + await httpsDefault.BindAsync(context).ConfigureAwait(false); + context.Logger.LogDebug(CoreStrings.BindingToDefaultAddresses, + Constants.DefaultServerAddress, Constants.DefaultServerHttpsAddress); + } + else + { + // No default cert is available, do not bind to the https endpoint. + context.Logger.LogDebug(CoreStrings.BindingToDefaultAddress, Constants.DefaultServerAddress); + } + } + } + + private class OverrideWithAddressesStrategy : AddressesStrategy + { + public OverrideWithAddressesStrategy(IReadOnlyCollection addresses) + : base(addresses) + { + } + + public override Task BindAsync(AddressBindContext context) + { + var joined = string.Join(", ", _addresses); + context.Logger.LogInformation(CoreStrings.OverridingWithPreferHostingUrls, nameof(IServerAddressesFeature.PreferHostingUrls), joined); + + return base.BindAsync(context); + } + } + + private class OverrideWithEndpointsStrategy : EndpointsStrategy + { + private readonly string[] _originalAddresses; + + public OverrideWithEndpointsStrategy(IReadOnlyCollection endpoints, string[] originalAddresses) + : base(endpoints) + { + _originalAddresses = originalAddresses; + } + + public override Task BindAsync(AddressBindContext context) + { + var joined = string.Join(", ", _originalAddresses); + context.Logger.LogWarning(CoreStrings.OverridingWithKestrelOptions, joined, "UseKestrel()"); + + return base.BindAsync(context); + } + } + + private class EndpointsStrategy : IStrategy + { + private readonly IReadOnlyCollection _endpoints; + + public EndpointsStrategy(IReadOnlyCollection endpoints) + { + _endpoints = endpoints; + } + + public virtual async Task BindAsync(AddressBindContext context) + { + foreach (var endpoint in _endpoints) + { + await endpoint.BindAsync(context).ConfigureAwait(false); + } + } + } + + private class AddressesStrategy : IStrategy + { + protected readonly IReadOnlyCollection _addresses; + + public AddressesStrategy(IReadOnlyCollection addresses) + { + _addresses = addresses; + } + + public virtual async Task BindAsync(AddressBindContext context) + { + foreach (var address in _addresses) + { + var options = ParseAddress(address, out var https); + context.ServerOptions.ApplyEndpointDefaults(options); + + if (https && !options.ConnectionAdapters.Any(f => f.IsHttps)) + { + options.UseHttps(); + } + + await options.BindAsync(context).ConfigureAwait(false); + } + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/BufferReader.cs b/src/Servers/Kestrel/Core/src/Internal/BufferReader.cs new file mode 100644 index 0000000000..5a1e75b69b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/BufferReader.cs @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Buffers; +using System.Runtime.CompilerServices; + +namespace System.Buffers +{ + internal ref struct BufferReader + { + private ReadOnlySpan _currentSpan; + private int _index; + + private ReadOnlySequence _sequence; + private SequencePosition _currentSequencePosition; + private SequencePosition _nextSequencePosition; + + private int _consumedBytes; + private bool _end; + + public BufferReader(ReadOnlySequence buffer) + { + _end = false; + _index = 0; + _consumedBytes = 0; + _sequence = buffer; + _currentSequencePosition = _sequence.Start; + _nextSequencePosition = _currentSequencePosition; + _currentSpan = ReadOnlySpan.Empty; + MoveNext(); + } + + public bool End => _end; + + public int CurrentSegmentIndex => _index; + + public SequencePosition Position => _sequence.GetPosition(_index, _currentSequencePosition); + + public ReadOnlySpan CurrentSegment => _currentSpan; + + public ReadOnlySpan UnreadSegment => _currentSpan.Slice(_index); + + public int ConsumedBytes => _consumedBytes; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Peek() + { + if (_end) + { + return -1; + } + return _currentSpan[_index]; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Read() + { + if (_end) + { + return -1; + } + + var value = _currentSpan[_index]; + _index++; + _consumedBytes++; + + if (_index >= _currentSpan.Length) + { + MoveNext(); + } + + return value; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void MoveNext() + { + var previous = _nextSequencePosition; + while (_sequence.TryGet(ref _nextSequencePosition, out var memory, true)) + { + _currentSequencePosition = previous; + _currentSpan = memory.Span; + _index = 0; + if (_currentSpan.Length > 0) + { + return; + } + } + _end = true; + } + + public void Advance(int byteCount) + { + if (byteCount < 0) + { + ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.length); + } + + _consumedBytes += byteCount; + + while (!_end && byteCount > 0) + { + if ((_index + byteCount) < _currentSpan.Length) + { + _index += byteCount; + byteCount = 0; + break; + } + + var remaining = (_currentSpan.Length - _index); + + _index += remaining; + byteCount -= remaining; + + MoveNext(); + } + + if (byteCount > 0) + { + ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.length); + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/BufferWriter.cs b/src/Servers/Kestrel/Core/src/Internal/BufferWriter.cs new file mode 100644 index 0000000000..1f33f3e4cb --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/BufferWriter.cs @@ -0,0 +1,96 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; + +namespace System.Buffers +{ + internal ref struct BufferWriter where T : IBufferWriter + { + private T _output; + private Span _span; + private int _buffered; + private long _bytesCommitted; + + public BufferWriter(T output) + { + _buffered = 0; + _bytesCommitted = 0; + _output = output; + _span = output.GetSpan(); + } + + public Span Span => _span; + public long BytesCommitted => _bytesCommitted; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Commit() + { + var buffered = _buffered; + if (buffered > 0) + { + _bytesCommitted += buffered; + _buffered = 0; + _output.Advance(buffered); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Advance(int count) + { + _buffered += count; + _span = _span.Slice(count); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Write(ReadOnlySpan source) + { + if (_span.Length >= source.Length) + { + source.CopyTo(_span); + Advance(source.Length); + } + else + { + WriteMultiBuffer(source); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Ensure(int count = 1) + { + if (_span.Length < count) + { + EnsureMore(count); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void EnsureMore(int count = 0) + { + if (_buffered > 0) + { + Commit(); + } + + _output.GetMemory(count); + _span = _output.GetSpan(); + } + + private void WriteMultiBuffer(ReadOnlySpan source) + { + while (source.Length > 0) + { + if (_span.Length == 0) + { + EnsureMore(); + } + + var writable = Math.Min(source.Length, _span.Length); + source.Slice(0, writable).CopyTo(_span); + source = source.Slice(writable); + Advance(writable); + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/CertificateLoader.cs b/src/Servers/Kestrel/Core/src/Internal/CertificateLoader.cs new file mode 100644 index 0000000000..ce9c17e340 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/CertificateLoader.cs @@ -0,0 +1,97 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Security.Cryptography.X509Certificates; +using Microsoft.AspNetCore.Server.Kestrel.Core; + +namespace Microsoft.AspNetCore.Server.Kestrel.Https.Internal +{ + public static class CertificateLoader + { + // See http://oid-info.com/get/1.3.6.1.5.5.7.3.1 + // Indicates that a certificate can be used as a SSL server certificate + private const string ServerAuthenticationOid = "1.3.6.1.5.5.7.3.1"; + + public static X509Certificate2 LoadFromStoreCert(string subject, string storeName, StoreLocation storeLocation, bool allowInvalid) + { + using (var store = new X509Store(storeName, storeLocation)) + { + X509Certificate2Collection storeCertificates = null; + X509Certificate2 foundCertificate = null; + + try + { + store.Open(OpenFlags.ReadOnly); + storeCertificates = store.Certificates; + var foundCertificates = storeCertificates.Find(X509FindType.FindBySubjectName, subject, !allowInvalid); + foundCertificate = foundCertificates + .OfType() + .Where(IsCertificateAllowedForServerAuth) + .OrderByDescending(certificate => certificate.NotAfter) + .FirstOrDefault(); + + if (foundCertificate == null) + { + throw new InvalidOperationException(CoreStrings.FormatCertNotFoundInStore(subject, storeLocation, storeName, allowInvalid)); + } + + return foundCertificate; + } + finally + { + DisposeCertificates(storeCertificates, except: foundCertificate); + } + } + } + + internal static bool IsCertificateAllowedForServerAuth(X509Certificate2 certificate) + { + /* If the Extended Key Usage extension is included, then we check that the serverAuth usage is included. (http://oid-info.com/get/1.3.6.1.5.5.7.3.1) + * If the Extended Key Usage extension is not included, then we assume the certificate is allowed for all usages. + * + * See also https://blogs.msdn.microsoft.com/kaushal/2012/02/17/client-certificates-vs-server-certificates/ + * + * From https://tools.ietf.org/html/rfc3280#section-4.2.1.13 "Certificate Extensions: Extended Key Usage" + * + * If the (Extended Key Usage) extension is present, then the certificate MUST only be used + * for one of the purposes indicated. If multiple purposes are + * indicated the application need not recognize all purposes indicated, + * as long as the intended purpose is present. Certificate using + * applications MAY require that a particular purpose be indicated in + * order for the certificate to be acceptable to that application. + */ + + var hasEkuExtension = false; + + foreach (var extension in certificate.Extensions.OfType()) + { + hasEkuExtension = true; + foreach (var oid in extension.EnhancedKeyUsages) + { + if (oid.Value.Equals(ServerAuthenticationOid, StringComparison.Ordinal)) + { + return true; + } + } + } + + return !hasEkuExtension; + } + + private static void DisposeCertificates(X509Certificate2Collection certificates, X509Certificate2 except) + { + if (certificates != null) + { + foreach (var certificate in certificates) + { + if (!certificate.Equals(except)) + { + certificate.Dispose(); + } + } + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/ClosedStream.cs b/src/Servers/Kestrel/Core/src/Internal/ClosedStream.cs new file mode 100644 index 0000000000..a744faf0ec --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/ClosedStream.cs @@ -0,0 +1,68 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Https.Internal +{ + internal class ClosedStream : Stream + { + private static readonly Task ZeroResultTask = Task.FromResult(result: 0); + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => false; + + public override long Length + { + get + { + throw new NotSupportedException(); + } + } + + public override long Position + { + get + { + throw new NotSupportedException(); + } + set + { + throw new NotSupportedException(); + } + } + + public override void Flush() + { + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return 0; + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return ZeroResultTask; + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/ConfigurationReader.cs b/src/Servers/Kestrel/Core/src/Internal/ConfigurationReader.cs new file mode 100644 index 0000000000..b36a9af94e --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/ConfigurationReader.cs @@ -0,0 +1,139 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Configuration; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + internal class ConfigurationReader + { + private IConfiguration _configuration; + private IDictionary _certificates; + private IList _endpoints; + + public ConfigurationReader(IConfiguration configuration) + { + _configuration = configuration ?? throw new ArgumentNullException(nameof(configuration)); + } + + public IDictionary Certificates + { + get + { + if (_certificates == null) + { + ReadCertificates(); + } + + return _certificates; + } + } + + public IEnumerable Endpoints + { + get + { + if (_endpoints == null) + { + ReadEndpoints(); + } + + return _endpoints; + } + } + + private void ReadCertificates() + { + _certificates = new Dictionary(0); + + var certificatesConfig = _configuration.GetSection("Certificates").GetChildren(); + foreach (var certificateConfig in certificatesConfig) + { + _certificates.Add(certificateConfig.Key, new CertificateConfig(certificateConfig)); + } + } + + private void ReadEndpoints() + { + _endpoints = new List(); + + var endpointsConfig = _configuration.GetSection("Endpoints").GetChildren(); + foreach (var endpointConfig in endpointsConfig) + { + // "EndpointName": { +        // "Url": "https://*:5463", +        // "Certificate": { +          // "Path": "testCert.pfx", +          // "Password": "testPassword" +       // } + // } + + var url = endpointConfig["Url"]; + if (string.IsNullOrEmpty(url)) + { + throw new InvalidOperationException(CoreStrings.FormatEndpointMissingUrl(endpointConfig.Key)); + } + + var endpoint = new EndpointConfig() + { + Name = endpointConfig.Key, + Url = url, + ConfigSection = endpointConfig, + Certificate = new CertificateConfig(endpointConfig.GetSection("Certificate")), + }; + _endpoints.Add(endpoint); + } + } + } + + // "EndpointName": { + // "Url": "https://*:5463", + // "Certificate": { + // "Path": "testCert.pfx", + // "Password": "testPassword" + // } + // } + internal class EndpointConfig + { + public string Name { get; set; } + public string Url { get; set; } + public IConfigurationSection ConfigSection { get; set; } + public CertificateConfig Certificate { get; set; } + } + + // "CertificateName": { + // "Path": "testCert.pfx", + // "Password": "testPassword" + // } + internal class CertificateConfig + { + public CertificateConfig(IConfigurationSection configSection) + { + ConfigSection = configSection; + ConfigSection.Bind(this); + } + + public IConfigurationSection ConfigSection { get; } + + // File + public bool IsFileCert => !string.IsNullOrEmpty(Path); + + public string Path { get; set; } + + public string Password { get; set; } + + // Cert store + + public bool IsStoreCert => !string.IsNullOrEmpty(Subject); + + public string Subject { get; set; } + + public string Store { get; set; } + + public string Location { get; set; } + + public bool? AllowInvalid { get; set; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/ConnectionDispatcher.cs b/src/Servers/Kestrel/Core/src/Internal/ConnectionDispatcher.cs new file mode 100644 index 0000000000..54b2d4a7e4 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/ConnectionDispatcher.cs @@ -0,0 +1,113 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public class ConnectionDispatcher : IConnectionDispatcher + { + private readonly ServiceContext _serviceContext; + private readonly ConnectionDelegate _connectionDelegate; + + public ConnectionDispatcher(ServiceContext serviceContext, ConnectionDelegate connectionDelegate) + { + _serviceContext = serviceContext; + _connectionDelegate = connectionDelegate; + } + + private IKestrelTrace Log => _serviceContext.Log; + + public void OnConnection(TransportConnection connection) + { + // REVIEW: Unfortunately, we still need to use the service context to create the pipes since the settings + // for the scheduler and limits are specified here + var inputOptions = GetInputPipeOptions(_serviceContext, connection.MemoryPool, connection.InputWriterScheduler); + var outputOptions = GetOutputPipeOptions(_serviceContext, connection.MemoryPool, connection.OutputReaderScheduler); + + var pair = DuplexPipe.CreateConnectionPair(inputOptions, outputOptions); + + // Set the transport and connection id + connection.ConnectionId = CorrelationIdGenerator.GetNextId(); + connection.Transport = pair.Transport; + + // This *must* be set before returning from OnConnection + connection.Application = pair.Application; + + // REVIEW: This task should be tracked by the server for graceful shutdown + // Today it's handled specifically for http but not for aribitrary middleware + _ = Execute(connection); + } + + private async Task Execute(ConnectionContext connectionContext) + { + using (BeginConnectionScope(connectionContext)) + { + try + { + await _connectionDelegate(connectionContext); + } + catch (Exception ex) + { + Log.LogCritical(0, ex, $"{nameof(ConnectionDispatcher)}.{nameof(Execute)}() {connectionContext.ConnectionId}"); + } + } + } + + private IDisposable BeginConnectionScope(ConnectionContext connectionContext) + { + if (Log.IsEnabled(LogLevel.Critical)) + { + return Log.BeginScope(new ConnectionLogScope(connectionContext.ConnectionId)); + } + + return null; + } + + // Internal for testing + internal static PipeOptions GetInputPipeOptions(ServiceContext serviceContext, MemoryPool memoryPool, PipeScheduler writerScheduler) => new PipeOptions + ( + pool: memoryPool, + readerScheduler: serviceContext.Scheduler, + writerScheduler: writerScheduler, + pauseWriterThreshold: serviceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0, + resumeWriterThreshold: serviceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0, + useSynchronizationContext: false, + minimumSegmentSize: KestrelMemoryPool.MinimumSegmentSize + ); + + internal static PipeOptions GetOutputPipeOptions(ServiceContext serviceContext, MemoryPool memoryPool, PipeScheduler readerScheduler) => new PipeOptions + ( + pool: memoryPool, + readerScheduler: readerScheduler, + writerScheduler: serviceContext.Scheduler, + pauseWriterThreshold: GetOutputResponseBufferSize(serviceContext), + resumeWriterThreshold: GetOutputResponseBufferSize(serviceContext), + useSynchronizationContext: false, + minimumSegmentSize: KestrelMemoryPool.MinimumSegmentSize + ); + + private static long GetOutputResponseBufferSize(ServiceContext serviceContext) + { + var bufferSize = serviceContext.ServerOptions.Limits.MaxResponseBufferSize; + if (bufferSize == 0) + { + // 0 = no buffering so we need to configure the pipe so the the writer waits on the reader directly + return 1; + } + + // null means that we have no back pressure + return bufferSize ?? 0; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/ConnectionLimitMiddleware.cs b/src/Servers/Kestrel/Core/src/Internal/ConnectionLimitMiddleware.cs new file mode 100644 index 0000000000..5be3d5479a --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/ConnectionLimitMiddleware.cs @@ -0,0 +1,74 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public class ConnectionLimitMiddleware + { + private readonly ConnectionDelegate _next; + private readonly ResourceCounter _concurrentConnectionCounter; + private readonly IKestrelTrace _trace; + + public ConnectionLimitMiddleware(ConnectionDelegate next, long connectionLimit, IKestrelTrace trace) + : this(next, ResourceCounter.Quota(connectionLimit), trace) + { + } + + // For Testing + internal ConnectionLimitMiddleware(ConnectionDelegate next, ResourceCounter concurrentConnectionCounter, IKestrelTrace trace) + { + _next = next; + _concurrentConnectionCounter = concurrentConnectionCounter; + _trace = trace; + } + + public async Task OnConnectionAsync(ConnectionContext connection) + { + if (!_concurrentConnectionCounter.TryLockOne()) + { + KestrelEventSource.Log.ConnectionRejected(connection.ConnectionId); + _trace.ConnectionRejected(connection.ConnectionId); + connection.Transport.Input.Complete(); + connection.Transport.Output.Complete(); + return; + } + + var releasor = new ConnectionReleasor(_concurrentConnectionCounter); + + try + { + connection.Features.Set(releasor); + await _next(connection); + } + finally + { + releasor.ReleaseConnection(); + } + } + + private class ConnectionReleasor : IDecrementConcurrentConnectionCountFeature + { + private readonly ResourceCounter _concurrentConnectionCounter; + private bool _connectionReleased; + + public ConnectionReleasor(ResourceCounter normalConnectionCounter) + { + _concurrentConnectionCounter = normalConnectionCounter; + } + + public void ReleaseConnection() + { + if (!_connectionReleased) + { + _connectionReleased = true; + _concurrentConnectionCounter.ReleaseOne(); + } + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/ConnectionLogScope.cs b/src/Servers/Kestrel/Core/src/Internal/ConnectionLogScope.cs new file mode 100644 index 0000000000..3d1b882e82 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/ConnectionLogScope.cs @@ -0,0 +1,63 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Globalization; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public class ConnectionLogScope : IReadOnlyList> + { + private readonly string _connectionId; + + private string _cachedToString; + + public ConnectionLogScope(string connectionId) + { + _connectionId = connectionId; + } + + public KeyValuePair this[int index] + { + get + { + if (index == 0) + { + return new KeyValuePair("ConnectionId", _connectionId); + } + + throw new ArgumentOutOfRangeException(nameof(index)); + } + } + + public int Count => 1; + + public IEnumerator> GetEnumerator() + { + for (int i = 0; i < Count; ++i) + { + yield return this[i]; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public override string ToString() + { + if (_cachedToString == null) + { + _cachedToString = string.Format( + CultureInfo.InvariantCulture, + "ConnectionId:{0}", + _connectionId); + } + + return _cachedToString; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/DuplexPipe.cs b/src/Servers/Kestrel/Core/src/Internal/DuplexPipe.cs new file mode 100644 index 0000000000..adc18c8c7d --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/DuplexPipe.cs @@ -0,0 +1,41 @@ +using System.Buffers; + +namespace System.IO.Pipelines +{ + internal class DuplexPipe : IDuplexPipe + { + public DuplexPipe(PipeReader reader, PipeWriter writer) + { + Input = reader; + Output = writer; + } + + public PipeReader Input { get; } + + public PipeWriter Output { get; } + + public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); + + var transportToApplication = new DuplexPipe(output.Reader, input.Writer); + var applicationToTransport = new DuplexPipe(input.Reader, output.Writer); + + return new DuplexPipePair(applicationToTransport, transportToApplication); + } + + // This class exists to work around issues with value tuple on .NET Framework + public readonly struct DuplexPipePair + { + public IDuplexPipe Transport { get; } + public IDuplexPipe Application { get; } + + public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) + { + Transport = transport; + Application = application; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/ChunkWriter.cs b/src/Servers/Kestrel/Core/src/Internal/Http/ChunkWriter.cs new file mode 100644 index 0000000000..3d8cc4566b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/ChunkWriter.cs @@ -0,0 +1,63 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + internal static class ChunkWriter + { + private static readonly ArraySegment _endChunkBytes = CreateAsciiByteArraySegment("\r\n"); + private static readonly byte[] _hex = Encoding.ASCII.GetBytes("0123456789abcdef"); + + private static ArraySegment CreateAsciiByteArraySegment(string text) + { + var bytes = Encoding.ASCII.GetBytes(text); + return new ArraySegment(bytes); + } + + public static ArraySegment BeginChunkBytes(int dataCount) + { + var bytes = new byte[10] + { + _hex[((dataCount >> 0x1c) & 0x0f)], + _hex[((dataCount >> 0x18) & 0x0f)], + _hex[((dataCount >> 0x14) & 0x0f)], + _hex[((dataCount >> 0x10) & 0x0f)], + _hex[((dataCount >> 0x0c) & 0x0f)], + _hex[((dataCount >> 0x08) & 0x0f)], + _hex[((dataCount >> 0x04) & 0x0f)], + _hex[((dataCount >> 0x00) & 0x0f)], + (byte)'\r', + (byte)'\n', + }; + + // Determine the most-significant non-zero nibble + int total, shift; + total = (dataCount > 0xffff) ? 0x10 : 0x00; + dataCount >>= total; + shift = (dataCount > 0x00ff) ? 0x08 : 0x00; + dataCount >>= shift; + total |= shift; + total |= (dataCount > 0x000f) ? 0x04 : 0x00; + + var offset = 7 - (total >> 2); + return new ArraySegment(bytes, offset, 10 - offset); + } + + internal static int WriteBeginChunkBytes(ref BufferWriter start, int dataCount) + { + var chunkSegment = BeginChunkBytes(dataCount); + start.Write(new ReadOnlySpan(chunkSegment.Array, chunkSegment.Offset, chunkSegment.Count)); + return chunkSegment.Count; + } + + internal static void WriteEndChunkBytes(ref BufferWriter start) + { + start.Write(new ReadOnlySpan(_endChunkBytes.Array, _endChunkBytes.Offset, _endChunkBytes.Count)); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/ConnectionOptions.cs b/src/Servers/Kestrel/Core/src/Internal/Http/ConnectionOptions.cs new file mode 100644 index 0000000000..71817aed69 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/ConnectionOptions.cs @@ -0,0 +1,16 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + [Flags] + public enum ConnectionOptions + { + None = 0, + Close = 1, + KeepAlive = 2, + Upgrade = 4 + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/DateHeaderValueManager.cs b/src/Servers/Kestrel/Core/src/Internal/Http/DateHeaderValueManager.cs new file mode 100644 index 0000000000..b2cb874364 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/DateHeaderValueManager.cs @@ -0,0 +1,73 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Text; +using System.Threading; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.Net.Http.Headers; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + /// + /// Manages the generation of the date header value. + /// + public class DateHeaderValueManager : IHeartbeatHandler + { + private static readonly byte[] _datePreambleBytes = Encoding.ASCII.GetBytes("\r\nDate: "); + + private DateHeaderValues _dateValues; + + /// + /// Initializes a new instance of the class. + /// + public DateHeaderValueManager() + : this(systemClock: new SystemClock()) + { + } + + // Internal for testing + internal DateHeaderValueManager(ISystemClock systemClock) + { + SetDateValues(systemClock.UtcNow); + } + + /// + /// Returns a value representing the current server date/time for use in the HTTP "Date" response header + /// in accordance with http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.18 + /// + /// The value in string and byte[] format. + public DateHeaderValues GetDateHeaderValues() => _dateValues; + + // Called by the Timer (background) thread + public void OnHeartbeat(DateTimeOffset now) + { + SetDateValues(now); + } + + /// + /// Sets date values from a provided ticks value + /// + /// A DateTimeOffset value + private void SetDateValues(DateTimeOffset value) + { + var dateValue = HeaderUtilities.FormatDate(value); + var dateBytes = new byte[_datePreambleBytes.Length + dateValue.Length]; + Buffer.BlockCopy(_datePreambleBytes, 0, dateBytes, 0, _datePreambleBytes.Length); + Encoding.ASCII.GetBytes(dateValue, 0, dateValue.Length, dateBytes, _datePreambleBytes.Length); + + var dateValues = new DateHeaderValues() + { + Bytes = dateBytes, + String = dateValue + }; + Volatile.Write(ref _dateValues, dateValues); + } + + public class DateHeaderValues + { + public byte[] Bytes; + public string String; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.FeatureCollection.cs new file mode 100644 index 0000000000..9264a75df7 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.FeatureCollection.cs @@ -0,0 +1,48 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public partial class Http1Connection : IHttpUpgradeFeature + { + bool IHttpUpgradeFeature.IsUpgradableRequest => IsUpgradableRequest; + + async Task IHttpUpgradeFeature.UpgradeAsync() + { + if (!((IHttpUpgradeFeature)this).IsUpgradableRequest) + { + throw new InvalidOperationException(CoreStrings.CannotUpgradeNonUpgradableRequest); + } + + if (IsUpgraded) + { + throw new InvalidOperationException(CoreStrings.UpgradeCannotBeCalledMultipleTimes); + } + + if (!ServiceContext.ConnectionManager.UpgradedConnectionCount.TryLockOne()) + { + throw new InvalidOperationException(CoreStrings.UpgradedConnectionLimitReached); + } + + IsUpgraded = true; + + ConnectionFeatures.Get()?.ReleaseConnection(); + + StatusCode = StatusCodes.Status101SwitchingProtocols; + ReasonPhrase = "Switching Protocols"; + ResponseHeaders["Connection"] = "Upgrade"; + + await FlushAsync(default(CancellationToken)); + + return _streams.Upgrade(); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs new file mode 100644 index 0000000000..61eac8a37c --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs @@ -0,0 +1,527 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Diagnostics; +using System.Globalization; +using System.IO.Pipelines; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Connections.Abstractions; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public partial class Http1Connection : HttpProtocol, IRequestProcessor + { + private const byte ByteAsterisk = (byte)'*'; + private const byte ByteForwardSlash = (byte)'/'; + private const string Asterisk = "*"; + + private readonly Http1ConnectionContext _context; + private readonly IHttpParser _parser; + protected readonly long _keepAliveTicks; + private readonly long _requestHeadersTimeoutTicks; + + private volatile bool _requestTimedOut; + private uint _requestCount; + + private HttpRequestTarget _requestTargetForm = HttpRequestTarget.Unknown; + private Uri _absoluteRequestTarget; + + private int _remainingRequestHeadersBytesAllowed; + + public Http1Connection(Http1ConnectionContext context) + : base(context) + { + _context = context; + _parser = ServiceContext.HttpParser; + _keepAliveTicks = ServerOptions.Limits.KeepAliveTimeout.Ticks; + _requestHeadersTimeoutTicks = ServerOptions.Limits.RequestHeadersTimeout.Ticks; + + Output = new Http1OutputProducer( + _context.Transport.Output, + _context.ConnectionId, + _context.ConnectionContext, + _context.ServiceContext.Log, + _context.TimeoutControl, + _context.ConnectionFeatures.Get()); + } + + public PipeReader Input => _context.Transport.Input; + + public ITimeoutControl TimeoutControl => _context.TimeoutControl; + public bool RequestTimedOut => _requestTimedOut; + + public override bool IsUpgradableRequest => _upgradeAvailable; + + /// + /// Stops the request processing loop between requests. + /// Called on all active connections when the server wants to initiate a shutdown + /// and after a keep-alive timeout. + /// + public void StopProcessingNextRequest() + { + _keepAlive = false; + Input.CancelPendingRead(); + } + + public void SendTimeoutResponse() + { + _requestTimedOut = true; + Input.CancelPendingRead(); + } + + public void ParseRequest(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + consumed = buffer.Start; + examined = buffer.End; + + switch (_requestProcessingStatus) + { + case RequestProcessingStatus.RequestPending: + if (buffer.IsEmpty) + { + break; + } + + TimeoutControl.ResetTimeout(_requestHeadersTimeoutTicks, TimeoutAction.SendTimeoutResponse); + + _requestProcessingStatus = RequestProcessingStatus.ParsingRequestLine; + goto case RequestProcessingStatus.ParsingRequestLine; + case RequestProcessingStatus.ParsingRequestLine: + if (TakeStartLine(buffer, out consumed, out examined)) + { + buffer = buffer.Slice(consumed, buffer.End); + + _requestProcessingStatus = RequestProcessingStatus.ParsingHeaders; + goto case RequestProcessingStatus.ParsingHeaders; + } + else + { + break; + } + case RequestProcessingStatus.ParsingHeaders: + if (TakeMessageHeaders(buffer, out consumed, out examined)) + { + _requestProcessingStatus = RequestProcessingStatus.AppStarted; + } + break; + } + } + + public bool TakeStartLine(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + var overLength = false; + if (buffer.Length >= ServerOptions.Limits.MaxRequestLineSize) + { + buffer = buffer.Slice(buffer.Start, ServerOptions.Limits.MaxRequestLineSize); + overLength = true; + } + + var result = _parser.ParseRequestLine(new Http1ParsingHandler(this), buffer, out consumed, out examined); + if (!result && overLength) + { + BadHttpRequestException.Throw(RequestRejectionReason.RequestLineTooLong); + } + + return result; + } + + public bool TakeMessageHeaders(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + // Make sure the buffer is limited + bool overLength = false; + if (buffer.Length >= _remainingRequestHeadersBytesAllowed) + { + buffer = buffer.Slice(buffer.Start, _remainingRequestHeadersBytesAllowed); + + // If we sliced it means the current buffer bigger than what we're + // allowed to look at + overLength = true; + } + + var result = _parser.ParseHeaders(new Http1ParsingHandler(this), buffer, out consumed, out examined, out var consumedBytes); + _remainingRequestHeadersBytesAllowed -= consumedBytes; + + if (!result && overLength) + { + BadHttpRequestException.Throw(RequestRejectionReason.HeadersExceedMaxTotalSize); + } + if (result) + { + TimeoutControl.CancelTimeout(); + } + + return result; + } + + public void OnStartLine(HttpMethod method, HttpVersion version, Span target, Span path, Span query, Span customMethod, bool pathEncoded) + { + Debug.Assert(target.Length != 0, "Request target must be non-zero length"); + + var ch = target[0]; + if (ch == ByteForwardSlash) + { + // origin-form. + // The most common form of request-target. + // https://tools.ietf.org/html/rfc7230#section-5.3.1 + OnOriginFormTarget(method, version, target, path, query, customMethod, pathEncoded); + } + else if (ch == ByteAsterisk && target.Length == 1) + { + OnAsteriskFormTarget(method); + } + else if (target.GetKnownHttpScheme(out var scheme)) + { + OnAbsoluteFormTarget(target, query); + } + else + { + // Assume anything else is considered authority form. + // FYI: this should be an edge case. This should only happen when + // a client mistakenly thinks this server is a proxy server. + OnAuthorityFormTarget(method, target); + } + + Method = method; + if (method == HttpMethod.Custom) + { + _methodText = customMethod.GetAsciiStringNonNullCharacters(); + } + + _httpVersion = version; + + Debug.Assert(RawTarget != null, "RawTarget was not set"); + Debug.Assert(((IHttpRequestFeature)this).Method != null, "Method was not set"); + Debug.Assert(Path != null, "Path was not set"); + Debug.Assert(QueryString != null, "QueryString was not set"); + Debug.Assert(HttpVersion != null, "HttpVersion was not set"); + } + + private void OnOriginFormTarget(HttpMethod method, HttpVersion version, Span target, Span path, Span query, Span customMethod, bool pathEncoded) + { + Debug.Assert(target[0] == ByteForwardSlash, "Should only be called when path starts with /"); + + _requestTargetForm = HttpRequestTarget.OriginForm; + + // URIs are always encoded/escaped to ASCII https://tools.ietf.org/html/rfc3986#page-11 + // Multibyte Internationalized Resource Identifiers (IRIs) are first converted to utf8; + // then encoded/escaped to ASCII https://www.ietf.org/rfc/rfc3987.txt "Mapping of IRIs to URIs" + string requestUrlPath = null; + string rawTarget = null; + + try + { + // Read raw target before mutating memory. + rawTarget = target.GetAsciiStringNonNullCharacters(); + + if (pathEncoded) + { + // URI was encoded, unescape and then parse as UTF-8 + // Disabling warning temporary + var pathLength = UrlDecoder.Decode(path, path); + + // Removing dot segments must be done after unescaping. From RFC 3986: + // + // URI producing applications should percent-encode data octets that + // correspond to characters in the reserved set unless these characters + // are specifically allowed by the URI scheme to represent data in that + // component. If a reserved character is found in a URI component and + // no delimiting role is known for that character, then it must be + // interpreted as representing the data octet corresponding to that + // character's encoding in US-ASCII. + // + // https://tools.ietf.org/html/rfc3986#section-2.2 + pathLength = PathNormalizer.RemoveDotSegments(path.Slice(0, pathLength)); + + requestUrlPath = GetUtf8String(path.Slice(0, pathLength)); + } + else + { + var pathLength = PathNormalizer.RemoveDotSegments(path); + + if (path.Length == pathLength && query.Length == 0) + { + // If no decoding was required, no dot segments were removed and + // there is no query, the request path is the same as the raw target + requestUrlPath = rawTarget; + } + else + { + requestUrlPath = path.Slice(0, pathLength).GetAsciiStringNonNullCharacters(); + } + } + } + catch (InvalidOperationException) + { + ThrowRequestTargetRejected(target); + } + + QueryString = query.GetAsciiStringNonNullCharacters(); + RawTarget = rawTarget; + Path = requestUrlPath; + } + + private void OnAuthorityFormTarget(HttpMethod method, Span target) + { + _requestTargetForm = HttpRequestTarget.AuthorityForm; + + // This is not complete validation. It is just a quick scan for invalid characters + // but doesn't check that the target fully matches the URI spec. + for (var i = 0; i < target.Length; i++) + { + var ch = target[i]; + if (!UriUtilities.IsValidAuthorityCharacter(ch)) + { + ThrowRequestTargetRejected(target); + } + } + + // The authority-form of request-target is only used for CONNECT + // requests (https://tools.ietf.org/html/rfc7231#section-4.3.6). + if (method != HttpMethod.Connect) + { + BadHttpRequestException.Throw(RequestRejectionReason.ConnectMethodRequired); + } + + // When making a CONNECT request to establish a tunnel through one or + // more proxies, a client MUST send only the target URI's authority + // component (excluding any userinfo and its "@" delimiter) as the + // request-target.For example, + // + // CONNECT www.example.com:80 HTTP/1.1 + // + // Allowed characters in the 'host + port' section of authority. + // See https://tools.ietf.org/html/rfc3986#section-3.2 + RawTarget = target.GetAsciiStringNonNullCharacters(); + Path = string.Empty; + QueryString = string.Empty; + } + + private void OnAsteriskFormTarget(HttpMethod method) + { + _requestTargetForm = HttpRequestTarget.AsteriskForm; + + // The asterisk-form of request-target is only used for a server-wide + // OPTIONS request (https://tools.ietf.org/html/rfc7231#section-4.3.7). + if (method != HttpMethod.Options) + { + BadHttpRequestException.Throw(RequestRejectionReason.OptionsMethodRequired); + } + + RawTarget = Asterisk; + Path = string.Empty; + QueryString = string.Empty; + } + + private void OnAbsoluteFormTarget(Span target, Span query) + { + _requestTargetForm = HttpRequestTarget.AbsoluteForm; + + // absolute-form + // https://tools.ietf.org/html/rfc7230#section-5.3.2 + + // This code should be the edge-case. + + // From the spec: + // a server MUST accept the absolute-form in requests, even though + // HTTP/1.1 clients will only send them in requests to proxies. + + RawTarget = target.GetAsciiStringNonNullCharacters(); + + // Validation of absolute URIs is slow, but clients + // should not be sending this form anyways, so perf optimization + // not high priority + + if (!Uri.TryCreate(RawTarget, UriKind.Absolute, out var uri)) + { + ThrowRequestTargetRejected(target); + } + + _absoluteRequestTarget = uri; + Path = uri.LocalPath; + // don't use uri.Query because we need the unescaped version + QueryString = query.GetAsciiStringNonNullCharacters(); + } + + private static unsafe string GetUtf8String(Span path) + { + // .NET 451 doesn't have pointer overloads for Encoding.GetString so we + // copy to an array + fixed (byte* pointer = &MemoryMarshal.GetReference(path)) + { + return Encoding.UTF8.GetString(pointer, path.Length); + } + } + + internal void EnsureHostHeaderExists() + { + // https://tools.ietf.org/html/rfc7230#section-5.4 + // A server MUST respond with a 400 (Bad Request) status code to any + // HTTP/1.1 request message that lacks a Host header field and to any + // request message that contains more than one Host header field or a + // Host header field with an invalid field-value. + + var hostCount = HttpRequestHeaders.HostCount; + var hostText = HttpRequestHeaders.HeaderHost.ToString(); + if (hostCount <= 0) + { + if (_httpVersion == Http.HttpVersion.Http10) + { + return; + } + BadHttpRequestException.Throw(RequestRejectionReason.MissingHostHeader); + } + else if (hostCount > 1) + { + BadHttpRequestException.Throw(RequestRejectionReason.MultipleHostHeaders); + } + else if (_requestTargetForm != HttpRequestTarget.OriginForm) + { + // Tail call + ValidateNonOrginHostHeader(hostText); + } + else + { + // Tail call + HttpUtilities.ValidateHostHeader(hostText); + } + } + + private void ValidateNonOrginHostHeader(string hostText) + { + if (_requestTargetForm == HttpRequestTarget.AuthorityForm) + { + if (hostText != RawTarget) + { + BadHttpRequestException.Throw(RequestRejectionReason.InvalidHostHeader, hostText); + } + } + else if (_requestTargetForm == HttpRequestTarget.AbsoluteForm) + { + // If the target URI includes an authority component, then a + // client MUST send a field - value for Host that is identical to that + // authority component, excluding any userinfo subcomponent and its "@" + // delimiter. + + // System.Uri doesn't not tell us if the port was in the original string or not. + // When IsDefaultPort = true, we will allow Host: with or without the default port + if (hostText != _absoluteRequestTarget.Authority) + { + if (!_absoluteRequestTarget.IsDefaultPort + || hostText != _absoluteRequestTarget.Authority + ":" + _absoluteRequestTarget.Port.ToString(CultureInfo.InvariantCulture)) + { + BadHttpRequestException.Throw(RequestRejectionReason.InvalidHostHeader, hostText); + } + } + } + + // Tail call + HttpUtilities.ValidateHostHeader(hostText); + } + + protected override void OnReset() + { + ResetIHttpUpgradeFeature(); + + _requestTimedOut = false; + _requestTargetForm = HttpRequestTarget.Unknown; + _absoluteRequestTarget = null; + _remainingRequestHeadersBytesAllowed = ServerOptions.Limits.MaxRequestHeadersTotalSize + 2; + _requestCount++; + } + + protected override void OnRequestProcessingEnding() + { + Input.Complete(); + } + + protected override string CreateRequestId() + => StringUtilities.ConcatAsHexSuffix(ConnectionId, ':', _requestCount); + + protected override MessageBody CreateMessageBody() + => Http1MessageBody.For(_httpVersion, HttpRequestHeaders, this); + + protected override void BeginRequestProcessing() + { + // Reset the features and timeout. + Reset(); + TimeoutControl.SetTimeout(_keepAliveTicks, TimeoutAction.StopProcessingNextRequest); + } + + protected override bool BeginRead(out ValueTask awaitable) + { + awaitable = Input.ReadAsync(); + return true; + } + + protected override bool TryParseRequest(ReadResult result, out bool endConnection) + { + var examined = result.Buffer.End; + var consumed = result.Buffer.End; + + try + { + ParseRequest(result.Buffer, out consumed, out examined); + } + catch (InvalidOperationException) + { + if (_requestProcessingStatus == RequestProcessingStatus.ParsingHeaders) + { + BadHttpRequestException.Throw(RequestRejectionReason.MalformedRequestInvalidHeaders); + } + throw; + } + finally + { + Input.AdvanceTo(consumed, examined); + } + + if (result.IsCompleted) + { + switch (_requestProcessingStatus) + { + case RequestProcessingStatus.RequestPending: + endConnection = true; + return true; + case RequestProcessingStatus.ParsingRequestLine: + BadHttpRequestException.Throw(RequestRejectionReason.InvalidRequestLine); + break; + case RequestProcessingStatus.ParsingHeaders: + BadHttpRequestException.Throw(RequestRejectionReason.MalformedRequestInvalidHeaders); + break; + } + } + else if (!_keepAlive && _requestProcessingStatus == RequestProcessingStatus.RequestPending) + { + // Stop the request processing loop if the server is shutting down or there was a keep-alive timeout + // and there is no ongoing request. + endConnection = true; + return true; + } + else if (RequestTimedOut) + { + // In this case, there is an ongoing request but the start line/header parsing has timed out, so send + // a 408 response. + BadHttpRequestException.Throw(RequestRejectionReason.RequestHeadersTimeout); + } + + endConnection = false; + if (_requestProcessingStatus == RequestProcessingStatus.AppStarted) + { + EnsureHostHeaderExists(); + return true; + } + else + { + return false; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ConnectionContext.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ConnectionContext.cs new file mode 100644 index 0000000000..28b5f95ed4 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ConnectionContext.cs @@ -0,0 +1,26 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.IO.Pipelines; +using System.Net; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public class Http1ConnectionContext : IHttpProtocolContext + { + public string ConnectionId { get; set; } + public ServiceContext ServiceContext { get; set; } + public ConnectionContext ConnectionContext { get; set; } + public IFeatureCollection ConnectionFeatures { get; set; } + public MemoryPool MemoryPool { get; set; } + public IPEndPoint RemoteEndPoint { get; set; } + public IPEndPoint LocalEndPoint { get; set; } + public ITimeoutControl TimeoutControl { get; set; } + public IDuplexPipe Transport { get; set; } + public IDuplexPipe Application { get; set; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs new file mode 100644 index 0000000000..f11619d7f3 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs @@ -0,0 +1,716 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO; +using System.IO.Pipelines; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public abstract class Http1MessageBody : MessageBody + { + private readonly Http1Connection _context; + + private volatile bool _canceled; + private Task _pumpTask; + + protected Http1MessageBody(Http1Connection context) + : base(context) + { + _context = context; + } + + private async Task PumpAsync() + { + Exception error = null; + + try + { + var awaitable = _context.Input.ReadAsync(); + + if (!awaitable.IsCompleted) + { + TryProduceContinue(); + } + + TryStartTimingReads(); + + while (true) + { + var result = await awaitable; + + if (_context.RequestTimedOut) + { + BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTimeout); + } + + var readableBuffer = result.Buffer; + var consumed = readableBuffer.Start; + var examined = readableBuffer.End; + + try + { + if (_canceled) + { + break; + } + + if (!readableBuffer.IsEmpty) + { + bool done; + done = Read(readableBuffer, _context.RequestBodyPipe.Writer, out consumed, out examined); + + var writeAwaitable = _context.RequestBodyPipe.Writer.FlushAsync(); + var backpressure = false; + + if (!writeAwaitable.IsCompleted) + { + // Backpressure, stop controlling incoming data rate until data is read. + backpressure = true; + TryPauseTimingReads(); + } + + await writeAwaitable; + + if (backpressure) + { + TryResumeTimingReads(); + } + + if (done) + { + break; + } + } + + // Read() will have already have greedily consumed the entire request body if able. + if (result.IsCompleted) + { + // Treat any FIN from an upgraded request as expected. + // It's up to higher-level consumer (i.e. WebSocket middleware) to determine + // if the end is actually expected based on higher-level framing. + if (RequestUpgrade) + { + break; + } + + BadHttpRequestException.Throw(RequestRejectionReason.UnexpectedEndOfRequestContent); + } + + } + finally + { + _context.Input.AdvanceTo(consumed, examined); + } + + awaitable = _context.Input.ReadAsync(); + } + } + catch (Exception ex) + { + error = ex; + } + finally + { + _context.RequestBodyPipe.Writer.Complete(error); + TryStopTimingReads(); + } + } + + public override Task StopAsync() + { + if (!_context.HasStartedConsumingRequestBody) + { + return Task.CompletedTask; + } + + _canceled = true; + _context.Input.CancelPendingRead(); + return _pumpTask; + } + + protected override Task OnConsumeAsync() + { + try + { + if (_context.RequestBodyPipe.Reader.TryRead(out var readResult)) + { + _context.RequestBodyPipe.Reader.AdvanceTo(readResult.Buffer.End); + + if (readResult.IsCompleted) + { + return Task.CompletedTask; + } + } + } + catch (BadHttpRequestException ex) + { + // At this point, the response has already been written, so this won't result in a 4XX response; + // however, we still need to stop the request processing loop and log. + _context.SetBadRequestState(ex); + return Task.CompletedTask; + } + + return OnConsumeAsyncAwaited(); + } + + private async Task OnConsumeAsyncAwaited() + { + Log.RequestBodyNotEntirelyRead(_context.ConnectionIdFeature, _context.TraceIdentifier); + + _context.TimeoutControl.SetTimeout(Constants.RequestBodyDrainTimeout.Ticks, TimeoutAction.AbortConnection); + + try + { + ReadResult result; + do + { + result = await _context.RequestBodyPipe.Reader.ReadAsync(); + _context.RequestBodyPipe.Reader.AdvanceTo(result.Buffer.End); + } while (!result.IsCompleted); + } + catch (BadHttpRequestException ex) + { + _context.SetBadRequestState(ex); + } + catch (ConnectionAbortedException) + { + Log.RequestBodyDrainTimedOut(_context.ConnectionIdFeature, _context.TraceIdentifier); + } + finally + { + _context.TimeoutControl.CancelTimeout(); + } + } + + protected void Copy(ReadOnlySequence readableBuffer, PipeWriter writableBuffer) + { + _context.TimeoutControl.BytesRead(readableBuffer.Length); + + if (readableBuffer.IsSingleSegment) + { + writableBuffer.Write(readableBuffer.First.Span); + } + else + { + foreach (var memory in readableBuffer) + { + writableBuffer.Write(memory.Span); + } + } + } + + protected override void OnReadStarted() + { + _pumpTask = PumpAsync(); + } + + protected virtual bool Read(ReadOnlySequence readableBuffer, PipeWriter writableBuffer, out SequencePosition consumed, out SequencePosition examined) + { + throw new NotImplementedException(); + } + + private void TryStartTimingReads() + { + if (!RequestUpgrade) + { + Log.RequestBodyStart(_context.ConnectionIdFeature, _context.TraceIdentifier); + _context.TimeoutControl.StartTimingReads(); + } + } + + private void TryPauseTimingReads() + { + if (!RequestUpgrade) + { + _context.TimeoutControl.PauseTimingReads(); + } + } + + private void TryResumeTimingReads() + { + if (!RequestUpgrade) + { + _context.TimeoutControl.ResumeTimingReads(); + } + } + + private void TryStopTimingReads() + { + if (!RequestUpgrade) + { + Log.RequestBodyDone(_context.ConnectionIdFeature, _context.TraceIdentifier); + _context.TimeoutControl.StopTimingReads(); + } + } + + public static MessageBody For( + HttpVersion httpVersion, + HttpRequestHeaders headers, + Http1Connection context) + { + // see also http://tools.ietf.org/html/rfc2616#section-4.4 + var keepAlive = httpVersion != HttpVersion.Http10; + + var upgrade = false; + if (headers.HasConnection) + { + var connectionOptions = HttpHeaders.ParseConnection(headers.HeaderConnection); + + upgrade = (connectionOptions & ConnectionOptions.Upgrade) == ConnectionOptions.Upgrade; + keepAlive = (connectionOptions & ConnectionOptions.KeepAlive) == ConnectionOptions.KeepAlive; + } + + if (upgrade) + { + if (headers.HeaderTransferEncoding.Count > 0 || (headers.ContentLength.HasValue && headers.ContentLength.Value != 0)) + { + BadHttpRequestException.Throw(RequestRejectionReason.UpgradeRequestCannotHavePayload); + } + + return new ForUpgrade(context); + } + + if (headers.HasTransferEncoding) + { + var transferEncoding = headers.HeaderTransferEncoding; + var transferCoding = HttpHeaders.GetFinalTransferCoding(transferEncoding); + + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If a Transfer-Encoding header field + // is present in a request and the chunked transfer coding is not + // the final encoding, the message body length cannot be determined + // reliably; the server MUST respond with the 400 (Bad Request) + // status code and then close the connection. + if (transferCoding != TransferCoding.Chunked) + { + BadHttpRequestException.Throw(RequestRejectionReason.FinalTransferCodingNotChunked, in transferEncoding); + } + + return new ForChunkedEncoding(keepAlive, context); + } + + if (headers.ContentLength.HasValue) + { + var contentLength = headers.ContentLength.Value; + + if (contentLength == 0) + { + return keepAlive ? MessageBody.ZeroContentLengthKeepAlive : MessageBody.ZeroContentLengthClose; + } + + return new ForContentLength(keepAlive, contentLength, context); + } + + // Avoid slowing down most common case + if (!object.ReferenceEquals(context.Method, HttpMethods.Get)) + { + // If we got here, request contains no Content-Length or Transfer-Encoding header. + // Reject with 411 Length Required. + if (context.Method == HttpMethod.Post || context.Method == HttpMethod.Put) + { + var requestRejectionReason = httpVersion == HttpVersion.Http11 ? RequestRejectionReason.LengthRequired : RequestRejectionReason.LengthRequiredHttp10; + BadHttpRequestException.Throw(requestRejectionReason, context.Method); + } + } + + return keepAlive ? MessageBody.ZeroContentLengthKeepAlive : MessageBody.ZeroContentLengthClose; + } + + private class ForUpgrade : Http1MessageBody + { + public ForUpgrade(Http1Connection context) + : base(context) + { + RequestUpgrade = true; + } + + public override bool IsEmpty => true; + + protected override bool Read(ReadOnlySequence readableBuffer, PipeWriter writableBuffer, out SequencePosition consumed, out SequencePosition examined) + { + Copy(readableBuffer, writableBuffer); + consumed = readableBuffer.End; + examined = readableBuffer.End; + return false; + } + } + + private class ForContentLength : Http1MessageBody + { + private readonly long _contentLength; + private long _inputLength; + + public ForContentLength(bool keepAlive, long contentLength, Http1Connection context) + : base(context) + { + RequestKeepAlive = keepAlive; + _contentLength = contentLength; + _inputLength = _contentLength; + } + + protected override bool Read(ReadOnlySequence readableBuffer, PipeWriter writableBuffer, out SequencePosition consumed, out SequencePosition examined) + { + if (_inputLength == 0) + { + throw new InvalidOperationException("Attempted to read from completed Content-Length request body."); + } + + var actual = (int)Math.Min(readableBuffer.Length, _inputLength); + _inputLength -= actual; + + consumed = readableBuffer.GetPosition(actual); + examined = consumed; + + Copy(readableBuffer.Slice(0, actual), writableBuffer); + + return _inputLength == 0; + } + + protected override void OnReadStarting() + { + if (_contentLength > _context.MaxRequestBodySize) + { + BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTooLarge); + } + } + } + + /// + /// http://tools.ietf.org/html/rfc2616#section-3.6.1 + /// + private class ForChunkedEncoding : Http1MessageBody + { + // byte consts don't have a data type annotation so we pre-cast it + private const byte ByteCR = (byte)'\r'; + // "7FFFFFFF\r\n" is the largest chunk size that could be returned as an int. + private const int MaxChunkPrefixBytes = 10; + + private long _inputLength; + private long _consumedBytes; + + private Mode _mode = Mode.Prefix; + + public ForChunkedEncoding(bool keepAlive, Http1Connection context) + : base(context) + { + RequestKeepAlive = keepAlive; + } + + protected override bool Read(ReadOnlySequence readableBuffer, PipeWriter writableBuffer, out SequencePosition consumed, out SequencePosition examined) + { + consumed = default(SequencePosition); + examined = default(SequencePosition); + + while (_mode < Mode.Trailer) + { + if (_mode == Mode.Prefix) + { + ParseChunkedPrefix(readableBuffer, out consumed, out examined); + + if (_mode == Mode.Prefix) + { + return false; + } + + readableBuffer = readableBuffer.Slice(consumed); + } + + if (_mode == Mode.Extension) + { + ParseExtension(readableBuffer, out consumed, out examined); + + if (_mode == Mode.Extension) + { + return false; + } + + readableBuffer = readableBuffer.Slice(consumed); + } + + if (_mode == Mode.Data) + { + ReadChunkedData(readableBuffer, writableBuffer, out consumed, out examined); + + if (_mode == Mode.Data) + { + return false; + } + + readableBuffer = readableBuffer.Slice(consumed); + } + + if (_mode == Mode.Suffix) + { + ParseChunkedSuffix(readableBuffer, out consumed, out examined); + + if (_mode == Mode.Suffix) + { + return false; + } + + readableBuffer = readableBuffer.Slice(consumed); + } + } + + // Chunks finished, parse trailers + if (_mode == Mode.Trailer) + { + ParseChunkedTrailer(readableBuffer, out consumed, out examined); + + if (_mode == Mode.Trailer) + { + return false; + } + + readableBuffer = readableBuffer.Slice(consumed); + } + + // _consumedBytes aren't tracked for trailer headers, since headers have separate limits. + if (_mode == Mode.TrailerHeaders) + { + if (_context.TakeMessageHeaders(readableBuffer, out consumed, out examined)) + { + _mode = Mode.Complete; + } + } + + return _mode == Mode.Complete; + } + + private void AddAndCheckConsumedBytes(long consumedBytes) + { + _consumedBytes += consumedBytes; + + if (_consumedBytes > _context.MaxRequestBodySize) + { + BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTooLarge); + } + } + + private void ParseChunkedPrefix(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + consumed = buffer.Start; + examined = buffer.Start; + var reader = new BufferReader(buffer); + var ch1 = reader.Read(); + var ch2 = reader.Read(); + + if (ch1 == -1 || ch2 == -1) + { + examined = reader.Position; + return; + } + + var chunkSize = CalculateChunkSize(ch1, 0); + ch1 = ch2; + + while (reader.ConsumedBytes < MaxChunkPrefixBytes) + { + if (ch1 == ';') + { + consumed = reader.Position; + examined = reader.Position; + + AddAndCheckConsumedBytes(reader.ConsumedBytes); + _inputLength = chunkSize; + _mode = Mode.Extension; + return; + } + + ch2 = reader.Read(); + if (ch2 == -1) + { + examined = reader.Position; + return; + } + + if (ch1 == '\r' && ch2 == '\n') + { + consumed = reader.Position; + examined = reader.Position; + + AddAndCheckConsumedBytes(reader.ConsumedBytes); + _inputLength = chunkSize; + _mode = chunkSize > 0 ? Mode.Data : Mode.Trailer; + return; + } + + chunkSize = CalculateChunkSize(ch1, chunkSize); + ch1 = ch2; + } + + // At this point, 10 bytes have been consumed which is enough to parse the max value "7FFFFFFF\r\n". + BadHttpRequestException.Throw(RequestRejectionReason.BadChunkSizeData); + } + + private void ParseExtension(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + // Chunk-extensions not currently parsed + // Just drain the data + consumed = buffer.Start; + examined = buffer.Start; + + do + { + SequencePosition? extensionCursorPosition = buffer.PositionOf(ByteCR); + if (extensionCursorPosition == null) + { + // End marker not found yet + consumed = buffer.End; + examined = buffer.End; + AddAndCheckConsumedBytes(buffer.Length); + return; + }; + + var extensionCursor = extensionCursorPosition.Value; + var charsToByteCRExclusive = buffer.Slice(0, extensionCursor).Length; + + var sufixBuffer = buffer.Slice(extensionCursor); + if (sufixBuffer.Length < 2) + { + consumed = extensionCursor; + examined = buffer.End; + AddAndCheckConsumedBytes(charsToByteCRExclusive); + return; + } + + sufixBuffer = sufixBuffer.Slice(0, 2); + var sufixSpan = sufixBuffer.ToSpan(); + + if (sufixSpan[1] == '\n') + { + // We consumed the \r\n at the end of the extension, so switch modes. + _mode = _inputLength > 0 ? Mode.Data : Mode.Trailer; + + consumed = sufixBuffer.End; + examined = sufixBuffer.End; + AddAndCheckConsumedBytes(charsToByteCRExclusive + 2); + } + else + { + // Don't consume suffixSpan[1] in case it is also a \r. + buffer = buffer.Slice(charsToByteCRExclusive + 1); + consumed = extensionCursor; + AddAndCheckConsumedBytes(charsToByteCRExclusive + 1); + } + } while (_mode == Mode.Extension); + } + + private void ReadChunkedData(ReadOnlySequence buffer, PipeWriter writableBuffer, out SequencePosition consumed, out SequencePosition examined) + { + var actual = Math.Min(buffer.Length, _inputLength); + consumed = buffer.GetPosition(actual); + examined = consumed; + + Copy(buffer.Slice(0, actual), writableBuffer); + + _inputLength -= actual; + AddAndCheckConsumedBytes(actual); + + if (_inputLength == 0) + { + _mode = Mode.Suffix; + } + } + + private void ParseChunkedSuffix(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + consumed = buffer.Start; + examined = buffer.Start; + + if (buffer.Length < 2) + { + examined = buffer.End; + return; + } + + var suffixBuffer = buffer.Slice(0, 2); + var suffixSpan = suffixBuffer.ToSpan(); + if (suffixSpan[0] == '\r' && suffixSpan[1] == '\n') + { + consumed = suffixBuffer.End; + examined = suffixBuffer.End; + AddAndCheckConsumedBytes(2); + _mode = Mode.Prefix; + } + else + { + BadHttpRequestException.Throw(RequestRejectionReason.BadChunkSuffix); + } + } + + private void ParseChunkedTrailer(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + consumed = buffer.Start; + examined = buffer.Start; + + if (buffer.Length < 2) + { + examined = buffer.End; + return; + } + + var trailerBuffer = buffer.Slice(0, 2); + var trailerSpan = trailerBuffer.ToSpan(); + + if (trailerSpan[0] == '\r' && trailerSpan[1] == '\n') + { + consumed = trailerBuffer.End; + examined = trailerBuffer.End; + AddAndCheckConsumedBytes(2); + _mode = Mode.Complete; + } + else + { + _mode = Mode.TrailerHeaders; + } + } + + private int CalculateChunkSize(int extraHexDigit, int currentParsedSize) + { + try + { + checked + { + if (extraHexDigit >= '0' && extraHexDigit <= '9') + { + return currentParsedSize * 0x10 + (extraHexDigit - '0'); + } + else if (extraHexDigit >= 'A' && extraHexDigit <= 'F') + { + return currentParsedSize * 0x10 + (extraHexDigit - ('A' - 10)); + } + else if (extraHexDigit >= 'a' && extraHexDigit <= 'f') + { + return currentParsedSize * 0x10 + (extraHexDigit - ('a' - 10)); + } + } + } + catch (OverflowException ex) + { + throw new IOException(CoreStrings.BadRequest_BadChunkSizeData, ex); + } + + BadHttpRequestException.Throw(RequestRejectionReason.BadChunkSizeData); + return -1; // can't happen, but compiler complains + } + + private enum Mode + { + Prefix, + Extension, + Data, + Suffix, + Trailer, + TrailerHeaders, + Complete + }; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs new file mode 100644 index 0000000000..685980d1c6 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs @@ -0,0 +1,294 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public class Http1OutputProducer : IHttpOutputProducer + { + private static readonly ReadOnlyMemory _continueBytes = new ReadOnlyMemory(Encoding.ASCII.GetBytes("HTTP/1.1 100 Continue\r\n\r\n")); + private static readonly byte[] _bytesHttpVersion11 = Encoding.ASCII.GetBytes("HTTP/1.1 "); + private static readonly byte[] _bytesEndHeaders = Encoding.ASCII.GetBytes("\r\n\r\n"); + private static readonly ReadOnlyMemory _endChunkedResponseBytes = new ReadOnlyMemory(Encoding.ASCII.GetBytes("0\r\n\r\n")); + + private readonly string _connectionId; + private readonly ConnectionContext _connectionContext; + private readonly ITimeoutControl _timeoutControl; + private readonly IKestrelTrace _log; + private readonly IBytesWrittenFeature _transportBytesWrittenFeature; + + // This locks access to to all of the below fields + private readonly object _contextLock = new object(); + + private bool _completed = false; + private bool _aborted; + private long _unflushedBytes; + private long _totalBytesCommitted; + + private readonly PipeWriter _pipeWriter; + + // https://github.com/dotnet/corefxlab/issues/1334 + // Pipelines don't support multiple awaiters on flush + // this is temporary until it does + private TaskCompletionSource _flushTcs; + private readonly object _flushLock = new object(); + private Action _flushCompleted; + + private ValueTask _flushTask; + + public Http1OutputProducer( + PipeWriter pipeWriter, + string connectionId, + ConnectionContext connectionContext, + IKestrelTrace log, + ITimeoutControl timeoutControl, + IBytesWrittenFeature transportBytesWrittenFeature) + { + _pipeWriter = pipeWriter; + _connectionId = connectionId; + _connectionContext = connectionContext; + _timeoutControl = timeoutControl; + _log = log; + _flushCompleted = OnFlushCompleted; + _transportBytesWrittenFeature = transportBytesWrittenFeature; + } + + public Task WriteDataAsync(ReadOnlySpan buffer, CancellationToken cancellationToken = default(CancellationToken)) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + return WriteAsync(buffer, cancellationToken); + } + + public Task WriteStreamSuffixAsync(CancellationToken cancellationToken) + { + return WriteAsync(_endChunkedResponseBytes.Span, cancellationToken); + } + + public Task FlushAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + return WriteAsync(Constants.EmptyData, cancellationToken); + } + + public void Write(Func callback, T state) + { + lock (_contextLock) + { + if (_completed) + { + return; + } + + var buffer = _pipeWriter; + var bytesCommitted = callback(buffer, state); + _unflushedBytes += bytesCommitted; + _totalBytesCommitted += bytesCommitted; + } + } + + public Task WriteAsync(Func callback, T state) + { + lock (_contextLock) + { + if (_completed) + { + return Task.CompletedTask; + } + + var buffer = _pipeWriter; + var bytesCommitted = callback(buffer, state); + _unflushedBytes += bytesCommitted; + _totalBytesCommitted += bytesCommitted; + } + + return FlushAsync(); + } + + public void WriteResponseHeaders(int statusCode, string reasonPhrase, HttpResponseHeaders responseHeaders) + { + lock (_contextLock) + { + if (_completed) + { + return; + } + + var buffer = _pipeWriter; + var writer = new BufferWriter(buffer); + + writer.Write(_bytesHttpVersion11); + var statusBytes = ReasonPhrases.ToStatusBytes(statusCode, reasonPhrase); + writer.Write(statusBytes); + responseHeaders.CopyTo(ref writer); + writer.Write(_bytesEndHeaders); + + writer.Commit(); + + _unflushedBytes += writer.BytesCommitted; + _totalBytesCommitted += writer.BytesCommitted; + } + } + + public void Dispose() + { + lock (_contextLock) + { + if (_completed) + { + return; + } + + _log.ConnectionDisconnect(_connectionId); + _completed = true; + _pipeWriter.Complete(); + + var unsentBytes = _totalBytesCommitted - _transportBytesWrittenFeature.TotalBytesWritten; + + if (unsentBytes > 0) + { + // unsentBytes should never be over 64KB in the default configuration. + _timeoutControl.StartTimingWrite((int)Math.Min(unsentBytes, int.MaxValue)); + _pipeWriter.OnReaderCompleted((ex, state) => ((ITimeoutControl)state).StopTimingWrite(), _timeoutControl); + } + } + } + + public void Abort(ConnectionAbortedException error) + { + // Abort can be called after Dispose if there's a flush timeout. + // It's important to still call _lifetimeFeature.Abort() in this case. + + lock (_contextLock) + { + if (_aborted) + { + return; + } + + _aborted = true; + _connectionContext.Abort(error); + + if (!_completed) + { + _log.ConnectionDisconnect(_connectionId); + _completed = true; + _pipeWriter.Complete(); + } + } + } + + public Task Write100ContinueAsync(CancellationToken cancellationToken) + { + return WriteAsync(_continueBytes.Span, default(CancellationToken)); + } + + private Task WriteAsync( + ReadOnlySpan buffer, + CancellationToken cancellationToken) + { + var writableBuffer = default(PipeWriter); + long bytesWritten = 0; + lock (_contextLock) + { + if (_completed) + { + return Task.CompletedTask; + } + + writableBuffer = _pipeWriter; + var writer = new BufferWriter(writableBuffer); + if (buffer.Length > 0) + { + writer.Write(buffer); + + _unflushedBytes += buffer.Length; + _totalBytesCommitted += buffer.Length; + } + writer.Commit(); + + bytesWritten = _unflushedBytes; + _unflushedBytes = 0; + } + + return FlushAsync(writableBuffer, bytesWritten, cancellationToken); + } + + // Single caller, at end of method - so inline + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private Task FlushAsync(PipeWriter writableBuffer, long bytesWritten, CancellationToken cancellationToken) + { + var awaitable = writableBuffer.FlushAsync(cancellationToken); + if (awaitable.IsCompleted) + { + // The flush task can't fail today + return Task.CompletedTask; + } + return FlushAsyncAwaited(awaitable, bytesWritten, cancellationToken); + } + + private async Task FlushAsyncAwaited(ValueTask awaitable, long count, CancellationToken cancellationToken) + { + // https://github.com/dotnet/corefxlab/issues/1334 + // Since the flush awaitable doesn't currently support multiple awaiters + // we need to use a task to track the callbacks. + // All awaiters get the same task + lock (_flushLock) + { + _flushTask = awaitable; + if (_flushTcs == null || _flushTcs.Task.IsCompleted) + { + _flushTcs = new TaskCompletionSource(); + + _flushTask.GetAwaiter().OnCompleted(_flushCompleted); + } + } + + _timeoutControl.StartTimingWrite(count); + try + { + await _flushTcs.Task; + cancellationToken.ThrowIfCancellationRequested(); + } + catch (OperationCanceledException) + { + _completed = true; + throw; + } + finally + { + _timeoutControl.StopTimingWrite(); + } + } + + private void OnFlushCompleted() + { + try + { + _flushTask.GetAwaiter().GetResult(); + _flushTcs.TrySetResult(null); + } + catch (Exception exception) + { + _flushTcs.TrySetResult(exception); + } + finally + { + _flushTask = default; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ParsingHandler.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ParsingHandler.cs new file mode 100644 index 0000000000..e4385351db --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ParsingHandler.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public struct Http1ParsingHandler : IHttpRequestLineHandler, IHttpHeadersHandler + { + public Http1Connection Connection; + + public Http1ParsingHandler(Http1Connection connection) + { + Connection = connection; + } + + public void OnHeader(Span name, Span value) + => Connection.OnHeader(name, value); + + public void OnStartLine(HttpMethod method, HttpVersion version, Span target, Span path, Span query, Span customMethod, bool pathEncoded) + => Connection.OnStartLine(method, version, target, path, query, customMethod, pathEncoded); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs new file mode 100644 index 0000000000..d84f15706d --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs @@ -0,0 +1,9001 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using System.Buffers; +using System.IO.Pipelines; +using Microsoft.Extensions.Primitives; +using Microsoft.Net.Http.Headers; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + + public partial class HttpRequestHeaders + { + + private long _bits = 0; + private HeaderReferences _headers; + + public bool HasConnection => (_bits & 2L) != 0; + public bool HasTransferEncoding => (_bits & 64L) != 0; + + public int HostCount => _headers._Host.Count; + + public StringValues HeaderCacheControl + { + get + { + StringValues value; + if ((_bits & 1L) != 0) + { + value = _headers._CacheControl; + } + return value; + } + set + { + _bits |= 1L; + _headers._CacheControl = value; + } + } + public StringValues HeaderConnection + { + get + { + StringValues value; + if ((_bits & 2L) != 0) + { + value = _headers._Connection; + } + return value; + } + set + { + _bits |= 2L; + _headers._Connection = value; + } + } + public StringValues HeaderDate + { + get + { + StringValues value; + if ((_bits & 4L) != 0) + { + value = _headers._Date; + } + return value; + } + set + { + _bits |= 4L; + _headers._Date = value; + } + } + public StringValues HeaderKeepAlive + { + get + { + StringValues value; + if ((_bits & 8L) != 0) + { + value = _headers._KeepAlive; + } + return value; + } + set + { + _bits |= 8L; + _headers._KeepAlive = value; + } + } + public StringValues HeaderPragma + { + get + { + StringValues value; + if ((_bits & 16L) != 0) + { + value = _headers._Pragma; + } + return value; + } + set + { + _bits |= 16L; + _headers._Pragma = value; + } + } + public StringValues HeaderTrailer + { + get + { + StringValues value; + if ((_bits & 32L) != 0) + { + value = _headers._Trailer; + } + return value; + } + set + { + _bits |= 32L; + _headers._Trailer = value; + } + } + public StringValues HeaderTransferEncoding + { + get + { + StringValues value; + if ((_bits & 64L) != 0) + { + value = _headers._TransferEncoding; + } + return value; + } + set + { + _bits |= 64L; + _headers._TransferEncoding = value; + } + } + public StringValues HeaderUpgrade + { + get + { + StringValues value; + if ((_bits & 128L) != 0) + { + value = _headers._Upgrade; + } + return value; + } + set + { + _bits |= 128L; + _headers._Upgrade = value; + } + } + public StringValues HeaderVia + { + get + { + StringValues value; + if ((_bits & 256L) != 0) + { + value = _headers._Via; + } + return value; + } + set + { + _bits |= 256L; + _headers._Via = value; + } + } + public StringValues HeaderWarning + { + get + { + StringValues value; + if ((_bits & 512L) != 0) + { + value = _headers._Warning; + } + return value; + } + set + { + _bits |= 512L; + _headers._Warning = value; + } + } + public StringValues HeaderAllow + { + get + { + StringValues value; + if ((_bits & 1024L) != 0) + { + value = _headers._Allow; + } + return value; + } + set + { + _bits |= 1024L; + _headers._Allow = value; + } + } + public StringValues HeaderContentType + { + get + { + StringValues value; + if ((_bits & 2048L) != 0) + { + value = _headers._ContentType; + } + return value; + } + set + { + _bits |= 2048L; + _headers._ContentType = value; + } + } + public StringValues HeaderContentEncoding + { + get + { + StringValues value; + if ((_bits & 4096L) != 0) + { + value = _headers._ContentEncoding; + } + return value; + } + set + { + _bits |= 4096L; + _headers._ContentEncoding = value; + } + } + public StringValues HeaderContentLanguage + { + get + { + StringValues value; + if ((_bits & 8192L) != 0) + { + value = _headers._ContentLanguage; + } + return value; + } + set + { + _bits |= 8192L; + _headers._ContentLanguage = value; + } + } + public StringValues HeaderContentLocation + { + get + { + StringValues value; + if ((_bits & 16384L) != 0) + { + value = _headers._ContentLocation; + } + return value; + } + set + { + _bits |= 16384L; + _headers._ContentLocation = value; + } + } + public StringValues HeaderContentMD5 + { + get + { + StringValues value; + if ((_bits & 32768L) != 0) + { + value = _headers._ContentMD5; + } + return value; + } + set + { + _bits |= 32768L; + _headers._ContentMD5 = value; + } + } + public StringValues HeaderContentRange + { + get + { + StringValues value; + if ((_bits & 65536L) != 0) + { + value = _headers._ContentRange; + } + return value; + } + set + { + _bits |= 65536L; + _headers._ContentRange = value; + } + } + public StringValues HeaderExpires + { + get + { + StringValues value; + if ((_bits & 131072L) != 0) + { + value = _headers._Expires; + } + return value; + } + set + { + _bits |= 131072L; + _headers._Expires = value; + } + } + public StringValues HeaderLastModified + { + get + { + StringValues value; + if ((_bits & 262144L) != 0) + { + value = _headers._LastModified; + } + return value; + } + set + { + _bits |= 262144L; + _headers._LastModified = value; + } + } + public StringValues HeaderAccept + { + get + { + StringValues value; + if ((_bits & 524288L) != 0) + { + value = _headers._Accept; + } + return value; + } + set + { + _bits |= 524288L; + _headers._Accept = value; + } + } + public StringValues HeaderAcceptCharset + { + get + { + StringValues value; + if ((_bits & 1048576L) != 0) + { + value = _headers._AcceptCharset; + } + return value; + } + set + { + _bits |= 1048576L; + _headers._AcceptCharset = value; + } + } + public StringValues HeaderAcceptEncoding + { + get + { + StringValues value; + if ((_bits & 2097152L) != 0) + { + value = _headers._AcceptEncoding; + } + return value; + } + set + { + _bits |= 2097152L; + _headers._AcceptEncoding = value; + } + } + public StringValues HeaderAcceptLanguage + { + get + { + StringValues value; + if ((_bits & 4194304L) != 0) + { + value = _headers._AcceptLanguage; + } + return value; + } + set + { + _bits |= 4194304L; + _headers._AcceptLanguage = value; + } + } + public StringValues HeaderAuthorization + { + get + { + StringValues value; + if ((_bits & 8388608L) != 0) + { + value = _headers._Authorization; + } + return value; + } + set + { + _bits |= 8388608L; + _headers._Authorization = value; + } + } + public StringValues HeaderCookie + { + get + { + StringValues value; + if ((_bits & 16777216L) != 0) + { + value = _headers._Cookie; + } + return value; + } + set + { + _bits |= 16777216L; + _headers._Cookie = value; + } + } + public StringValues HeaderExpect + { + get + { + StringValues value; + if ((_bits & 33554432L) != 0) + { + value = _headers._Expect; + } + return value; + } + set + { + _bits |= 33554432L; + _headers._Expect = value; + } + } + public StringValues HeaderFrom + { + get + { + StringValues value; + if ((_bits & 67108864L) != 0) + { + value = _headers._From; + } + return value; + } + set + { + _bits |= 67108864L; + _headers._From = value; + } + } + public StringValues HeaderHost + { + get + { + StringValues value; + if ((_bits & 134217728L) != 0) + { + value = _headers._Host; + } + return value; + } + set + { + _bits |= 134217728L; + _headers._Host = value; + } + } + public StringValues HeaderIfMatch + { + get + { + StringValues value; + if ((_bits & 268435456L) != 0) + { + value = _headers._IfMatch; + } + return value; + } + set + { + _bits |= 268435456L; + _headers._IfMatch = value; + } + } + public StringValues HeaderIfModifiedSince + { + get + { + StringValues value; + if ((_bits & 536870912L) != 0) + { + value = _headers._IfModifiedSince; + } + return value; + } + set + { + _bits |= 536870912L; + _headers._IfModifiedSince = value; + } + } + public StringValues HeaderIfNoneMatch + { + get + { + StringValues value; + if ((_bits & 1073741824L) != 0) + { + value = _headers._IfNoneMatch; + } + return value; + } + set + { + _bits |= 1073741824L; + _headers._IfNoneMatch = value; + } + } + public StringValues HeaderIfRange + { + get + { + StringValues value; + if ((_bits & 2147483648L) != 0) + { + value = _headers._IfRange; + } + return value; + } + set + { + _bits |= 2147483648L; + _headers._IfRange = value; + } + } + public StringValues HeaderIfUnmodifiedSince + { + get + { + StringValues value; + if ((_bits & 4294967296L) != 0) + { + value = _headers._IfUnmodifiedSince; + } + return value; + } + set + { + _bits |= 4294967296L; + _headers._IfUnmodifiedSince = value; + } + } + public StringValues HeaderMaxForwards + { + get + { + StringValues value; + if ((_bits & 8589934592L) != 0) + { + value = _headers._MaxForwards; + } + return value; + } + set + { + _bits |= 8589934592L; + _headers._MaxForwards = value; + } + } + public StringValues HeaderProxyAuthorization + { + get + { + StringValues value; + if ((_bits & 17179869184L) != 0) + { + value = _headers._ProxyAuthorization; + } + return value; + } + set + { + _bits |= 17179869184L; + _headers._ProxyAuthorization = value; + } + } + public StringValues HeaderReferer + { + get + { + StringValues value; + if ((_bits & 34359738368L) != 0) + { + value = _headers._Referer; + } + return value; + } + set + { + _bits |= 34359738368L; + _headers._Referer = value; + } + } + public StringValues HeaderRange + { + get + { + StringValues value; + if ((_bits & 68719476736L) != 0) + { + value = _headers._Range; + } + return value; + } + set + { + _bits |= 68719476736L; + _headers._Range = value; + } + } + public StringValues HeaderTE + { + get + { + StringValues value; + if ((_bits & 137438953472L) != 0) + { + value = _headers._TE; + } + return value; + } + set + { + _bits |= 137438953472L; + _headers._TE = value; + } + } + public StringValues HeaderTranslate + { + get + { + StringValues value; + if ((_bits & 274877906944L) != 0) + { + value = _headers._Translate; + } + return value; + } + set + { + _bits |= 274877906944L; + _headers._Translate = value; + } + } + public StringValues HeaderUserAgent + { + get + { + StringValues value; + if ((_bits & 549755813888L) != 0) + { + value = _headers._UserAgent; + } + return value; + } + set + { + _bits |= 549755813888L; + _headers._UserAgent = value; + } + } + public StringValues HeaderOrigin + { + get + { + StringValues value; + if ((_bits & 1099511627776L) != 0) + { + value = _headers._Origin; + } + return value; + } + set + { + _bits |= 1099511627776L; + _headers._Origin = value; + } + } + public StringValues HeaderAccessControlRequestMethod + { + get + { + StringValues value; + if ((_bits & 2199023255552L) != 0) + { + value = _headers._AccessControlRequestMethod; + } + return value; + } + set + { + _bits |= 2199023255552L; + _headers._AccessControlRequestMethod = value; + } + } + public StringValues HeaderAccessControlRequestHeaders + { + get + { + StringValues value; + if ((_bits & 4398046511104L) != 0) + { + value = _headers._AccessControlRequestHeaders; + } + return value; + } + set + { + _bits |= 4398046511104L; + _headers._AccessControlRequestHeaders = value; + } + } + public StringValues HeaderContentLength + { + get + { + StringValues value; + if (_contentLength.HasValue) + { + value = new StringValues(HeaderUtilities.FormatNonNegativeInt64(_contentLength.Value)); + } + return value; + } + set + { + _contentLength = ParseContentLength(value); + } + } + + protected override int GetCountFast() + { + return (_contentLength.HasValue ? 1 : 0 ) + BitCount(_bits) + (MaybeUnknown?.Count ?? 0); + } + + protected override bool TryGetValueFast(string key, out StringValues value) + { + switch (key.Length) + { + case 13: + { + if ("Cache-Control".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1L) != 0) + { + value = _headers._CacheControl; + return true; + } + return false; + } + if ("Content-Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 65536L) != 0) + { + value = _headers._ContentRange; + return true; + } + return false; + } + if ("Last-Modified".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 262144L) != 0) + { + value = _headers._LastModified; + return true; + } + return false; + } + if ("Authorization".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8388608L) != 0) + { + value = _headers._Authorization; + return true; + } + return false; + } + if ("If-None-Match".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1073741824L) != 0) + { + value = _headers._IfNoneMatch; + return true; + } + return false; + } + } + break; + case 10: + { + if ("Connection".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2L) != 0) + { + value = _headers._Connection; + return true; + } + return false; + } + if ("Keep-Alive".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8L) != 0) + { + value = _headers._KeepAlive; + return true; + } + return false; + } + if ("User-Agent".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 549755813888L) != 0) + { + value = _headers._UserAgent; + return true; + } + return false; + } + } + break; + case 4: + { + if ("Date".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4L) != 0) + { + value = _headers._Date; + return true; + } + return false; + } + if ("From".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 67108864L) != 0) + { + value = _headers._From; + return true; + } + return false; + } + if ("Host".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 134217728L) != 0) + { + value = _headers._Host; + return true; + } + return false; + } + } + break; + case 6: + { + if ("Pragma".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16L) != 0) + { + value = _headers._Pragma; + return true; + } + return false; + } + if ("Accept".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 524288L) != 0) + { + value = _headers._Accept; + return true; + } + return false; + } + if ("Cookie".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16777216L) != 0) + { + value = _headers._Cookie; + return true; + } + return false; + } + if ("Expect".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 33554432L) != 0) + { + value = _headers._Expect; + return true; + } + return false; + } + if ("Origin".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1099511627776L) != 0) + { + value = _headers._Origin; + return true; + } + return false; + } + } + break; + case 7: + { + if ("Trailer".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 32L) != 0) + { + value = _headers._Trailer; + return true; + } + return false; + } + if ("Upgrade".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 128L) != 0) + { + value = _headers._Upgrade; + return true; + } + return false; + } + if ("Warning".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 512L) != 0) + { + value = _headers._Warning; + return true; + } + return false; + } + if ("Expires".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 131072L) != 0) + { + value = _headers._Expires; + return true; + } + return false; + } + if ("Referer".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 34359738368L) != 0) + { + value = _headers._Referer; + return true; + } + return false; + } + } + break; + case 17: + { + if ("Transfer-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 64L) != 0) + { + value = _headers._TransferEncoding; + return true; + } + return false; + } + if ("If-Modified-Since".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 536870912L) != 0) + { + value = _headers._IfModifiedSince; + return true; + } + return false; + } + } + break; + case 3: + { + if ("Via".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 256L) != 0) + { + value = _headers._Via; + return true; + } + return false; + } + } + break; + case 5: + { + if ("Allow".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1024L) != 0) + { + value = _headers._Allow; + return true; + } + return false; + } + if ("Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 68719476736L) != 0) + { + value = _headers._Range; + return true; + } + return false; + } + } + break; + case 12: + { + if ("Content-Type".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2048L) != 0) + { + value = _headers._ContentType; + return true; + } + return false; + } + if ("Max-Forwards".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8589934592L) != 0) + { + value = _headers._MaxForwards; + return true; + } + return false; + } + } + break; + case 16: + { + if ("Content-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4096L) != 0) + { + value = _headers._ContentEncoding; + return true; + } + return false; + } + if ("Content-Language".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8192L) != 0) + { + value = _headers._ContentLanguage; + return true; + } + return false; + } + if ("Content-Location".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16384L) != 0) + { + value = _headers._ContentLocation; + return true; + } + return false; + } + } + break; + case 11: + { + if ("Content-MD5".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 32768L) != 0) + { + value = _headers._ContentMD5; + return true; + } + return false; + } + } + break; + case 14: + { + if ("Accept-Charset".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1048576L) != 0) + { + value = _headers._AcceptCharset; + return true; + } + return false; + } + if ("Content-Length".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if (_contentLength.HasValue) + { + value = HeaderUtilities.FormatNonNegativeInt64(_contentLength.Value); + return true; + } + return false; + } + } + break; + case 15: + { + if ("Accept-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2097152L) != 0) + { + value = _headers._AcceptEncoding; + return true; + } + return false; + } + if ("Accept-Language".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4194304L) != 0) + { + value = _headers._AcceptLanguage; + return true; + } + return false; + } + } + break; + case 8: + { + if ("If-Match".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 268435456L) != 0) + { + value = _headers._IfMatch; + return true; + } + return false; + } + if ("If-Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2147483648L) != 0) + { + value = _headers._IfRange; + return true; + } + return false; + } + } + break; + case 19: + { + if ("If-Unmodified-Since".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4294967296L) != 0) + { + value = _headers._IfUnmodifiedSince; + return true; + } + return false; + } + if ("Proxy-Authorization".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 17179869184L) != 0) + { + value = _headers._ProxyAuthorization; + return true; + } + return false; + } + } + break; + case 2: + { + if ("TE".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 137438953472L) != 0) + { + value = _headers._TE; + return true; + } + return false; + } + } + break; + case 9: + { + if ("Translate".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 274877906944L) != 0) + { + value = _headers._Translate; + return true; + } + return false; + } + } + break; + case 29: + { + if ("Access-Control-Request-Method".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2199023255552L) != 0) + { + value = _headers._AccessControlRequestMethod; + return true; + } + return false; + } + } + break; + case 30: + { + if ("Access-Control-Request-Headers".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4398046511104L) != 0) + { + value = _headers._AccessControlRequestHeaders; + return true; + } + return false; + } + } + break; + } + + return MaybeUnknown?.TryGetValue(key, out value) ?? false; + } + + protected override void SetValueFast(string key, in StringValues value) + { + switch (key.Length) + { + case 13: + { + if ("Cache-Control".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 1L; + _headers._CacheControl = value; + return; + } + if ("Content-Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 65536L; + _headers._ContentRange = value; + return; + } + if ("Last-Modified".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 262144L; + _headers._LastModified = value; + return; + } + if ("Authorization".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 8388608L; + _headers._Authorization = value; + return; + } + if ("If-None-Match".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 1073741824L; + _headers._IfNoneMatch = value; + return; + } + } + break; + case 10: + { + if ("Connection".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 2L; + _headers._Connection = value; + return; + } + if ("Keep-Alive".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 8L; + _headers._KeepAlive = value; + return; + } + if ("User-Agent".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 549755813888L; + _headers._UserAgent = value; + return; + } + } + break; + case 4: + { + if ("Date".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 4L; + _headers._Date = value; + return; + } + if ("From".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 67108864L; + _headers._From = value; + return; + } + if ("Host".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 134217728L; + _headers._Host = value; + return; + } + } + break; + case 6: + { + if ("Pragma".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 16L; + _headers._Pragma = value; + return; + } + if ("Accept".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 524288L; + _headers._Accept = value; + return; + } + if ("Cookie".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 16777216L; + _headers._Cookie = value; + return; + } + if ("Expect".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 33554432L; + _headers._Expect = value; + return; + } + if ("Origin".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 1099511627776L; + _headers._Origin = value; + return; + } + } + break; + case 7: + { + if ("Trailer".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 32L; + _headers._Trailer = value; + return; + } + if ("Upgrade".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 128L; + _headers._Upgrade = value; + return; + } + if ("Warning".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 512L; + _headers._Warning = value; + return; + } + if ("Expires".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 131072L; + _headers._Expires = value; + return; + } + if ("Referer".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 34359738368L; + _headers._Referer = value; + return; + } + } + break; + case 17: + { + if ("Transfer-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 64L; + _headers._TransferEncoding = value; + return; + } + if ("If-Modified-Since".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 536870912L; + _headers._IfModifiedSince = value; + return; + } + } + break; + case 3: + { + if ("Via".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 256L; + _headers._Via = value; + return; + } + } + break; + case 5: + { + if ("Allow".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 1024L; + _headers._Allow = value; + return; + } + if ("Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 68719476736L; + _headers._Range = value; + return; + } + } + break; + case 12: + { + if ("Content-Type".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 2048L; + _headers._ContentType = value; + return; + } + if ("Max-Forwards".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 8589934592L; + _headers._MaxForwards = value; + return; + } + } + break; + case 16: + { + if ("Content-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 4096L; + _headers._ContentEncoding = value; + return; + } + if ("Content-Language".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 8192L; + _headers._ContentLanguage = value; + return; + } + if ("Content-Location".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 16384L; + _headers._ContentLocation = value; + return; + } + } + break; + case 11: + { + if ("Content-MD5".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 32768L; + _headers._ContentMD5 = value; + return; + } + } + break; + case 14: + { + if ("Accept-Charset".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 1048576L; + _headers._AcceptCharset = value; + return; + } + if ("Content-Length".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _contentLength = ParseContentLength(value.ToString()); + return; + } + } + break; + case 15: + { + if ("Accept-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 2097152L; + _headers._AcceptEncoding = value; + return; + } + if ("Accept-Language".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 4194304L; + _headers._AcceptLanguage = value; + return; + } + } + break; + case 8: + { + if ("If-Match".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 268435456L; + _headers._IfMatch = value; + return; + } + if ("If-Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 2147483648L; + _headers._IfRange = value; + return; + } + } + break; + case 19: + { + if ("If-Unmodified-Since".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 4294967296L; + _headers._IfUnmodifiedSince = value; + return; + } + if ("Proxy-Authorization".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 17179869184L; + _headers._ProxyAuthorization = value; + return; + } + } + break; + case 2: + { + if ("TE".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 137438953472L; + _headers._TE = value; + return; + } + } + break; + case 9: + { + if ("Translate".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 274877906944L; + _headers._Translate = value; + return; + } + } + break; + case 29: + { + if ("Access-Control-Request-Method".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 2199023255552L; + _headers._AccessControlRequestMethod = value; + return; + } + } + break; + case 30: + { + if ("Access-Control-Request-Headers".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 4398046511104L; + _headers._AccessControlRequestHeaders = value; + return; + } + } + break; + } + + SetValueUnknown(key, value); + } + + protected override bool AddValueFast(string key, in StringValues value) + { + switch (key.Length) + { + case 13: + { + if ("Cache-Control".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1L) == 0) + { + _bits |= 1L; + _headers._CacheControl = value; + return true; + } + return false; + } + if ("Content-Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 65536L) == 0) + { + _bits |= 65536L; + _headers._ContentRange = value; + return true; + } + return false; + } + if ("Last-Modified".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 262144L) == 0) + { + _bits |= 262144L; + _headers._LastModified = value; + return true; + } + return false; + } + if ("Authorization".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8388608L) == 0) + { + _bits |= 8388608L; + _headers._Authorization = value; + return true; + } + return false; + } + if ("If-None-Match".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1073741824L) == 0) + { + _bits |= 1073741824L; + _headers._IfNoneMatch = value; + return true; + } + return false; + } + } + break; + case 10: + { + if ("Connection".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2L) == 0) + { + _bits |= 2L; + _headers._Connection = value; + return true; + } + return false; + } + if ("Keep-Alive".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8L) == 0) + { + _bits |= 8L; + _headers._KeepAlive = value; + return true; + } + return false; + } + if ("User-Agent".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 549755813888L) == 0) + { + _bits |= 549755813888L; + _headers._UserAgent = value; + return true; + } + return false; + } + } + break; + case 4: + { + if ("Date".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4L) == 0) + { + _bits |= 4L; + _headers._Date = value; + return true; + } + return false; + } + if ("From".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 67108864L) == 0) + { + _bits |= 67108864L; + _headers._From = value; + return true; + } + return false; + } + if ("Host".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 134217728L) == 0) + { + _bits |= 134217728L; + _headers._Host = value; + return true; + } + return false; + } + } + break; + case 6: + { + if ("Pragma".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16L) == 0) + { + _bits |= 16L; + _headers._Pragma = value; + return true; + } + return false; + } + if ("Accept".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 524288L) == 0) + { + _bits |= 524288L; + _headers._Accept = value; + return true; + } + return false; + } + if ("Cookie".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16777216L) == 0) + { + _bits |= 16777216L; + _headers._Cookie = value; + return true; + } + return false; + } + if ("Expect".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 33554432L) == 0) + { + _bits |= 33554432L; + _headers._Expect = value; + return true; + } + return false; + } + if ("Origin".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1099511627776L) == 0) + { + _bits |= 1099511627776L; + _headers._Origin = value; + return true; + } + return false; + } + } + break; + case 7: + { + if ("Trailer".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 32L) == 0) + { + _bits |= 32L; + _headers._Trailer = value; + return true; + } + return false; + } + if ("Upgrade".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 128L) == 0) + { + _bits |= 128L; + _headers._Upgrade = value; + return true; + } + return false; + } + if ("Warning".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 512L) == 0) + { + _bits |= 512L; + _headers._Warning = value; + return true; + } + return false; + } + if ("Expires".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 131072L) == 0) + { + _bits |= 131072L; + _headers._Expires = value; + return true; + } + return false; + } + if ("Referer".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 34359738368L) == 0) + { + _bits |= 34359738368L; + _headers._Referer = value; + return true; + } + return false; + } + } + break; + case 17: + { + if ("Transfer-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 64L) == 0) + { + _bits |= 64L; + _headers._TransferEncoding = value; + return true; + } + return false; + } + if ("If-Modified-Since".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 536870912L) == 0) + { + _bits |= 536870912L; + _headers._IfModifiedSince = value; + return true; + } + return false; + } + } + break; + case 3: + { + if ("Via".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 256L) == 0) + { + _bits |= 256L; + _headers._Via = value; + return true; + } + return false; + } + } + break; + case 5: + { + if ("Allow".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1024L) == 0) + { + _bits |= 1024L; + _headers._Allow = value; + return true; + } + return false; + } + if ("Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 68719476736L) == 0) + { + _bits |= 68719476736L; + _headers._Range = value; + return true; + } + return false; + } + } + break; + case 12: + { + if ("Content-Type".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2048L) == 0) + { + _bits |= 2048L; + _headers._ContentType = value; + return true; + } + return false; + } + if ("Max-Forwards".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8589934592L) == 0) + { + _bits |= 8589934592L; + _headers._MaxForwards = value; + return true; + } + return false; + } + } + break; + case 16: + { + if ("Content-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4096L) == 0) + { + _bits |= 4096L; + _headers._ContentEncoding = value; + return true; + } + return false; + } + if ("Content-Language".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8192L) == 0) + { + _bits |= 8192L; + _headers._ContentLanguage = value; + return true; + } + return false; + } + if ("Content-Location".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16384L) == 0) + { + _bits |= 16384L; + _headers._ContentLocation = value; + return true; + } + return false; + } + } + break; + case 11: + { + if ("Content-MD5".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 32768L) == 0) + { + _bits |= 32768L; + _headers._ContentMD5 = value; + return true; + } + return false; + } + } + break; + case 14: + { + if ("Accept-Charset".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1048576L) == 0) + { + _bits |= 1048576L; + _headers._AcceptCharset = value; + return true; + } + return false; + } + if ("Content-Length".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if (!_contentLength.HasValue) + { + _contentLength = ParseContentLength(value); + return true; + } + return false; + } + } + break; + case 15: + { + if ("Accept-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2097152L) == 0) + { + _bits |= 2097152L; + _headers._AcceptEncoding = value; + return true; + } + return false; + } + if ("Accept-Language".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4194304L) == 0) + { + _bits |= 4194304L; + _headers._AcceptLanguage = value; + return true; + } + return false; + } + } + break; + case 8: + { + if ("If-Match".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 268435456L) == 0) + { + _bits |= 268435456L; + _headers._IfMatch = value; + return true; + } + return false; + } + if ("If-Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2147483648L) == 0) + { + _bits |= 2147483648L; + _headers._IfRange = value; + return true; + } + return false; + } + } + break; + case 19: + { + if ("If-Unmodified-Since".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4294967296L) == 0) + { + _bits |= 4294967296L; + _headers._IfUnmodifiedSince = value; + return true; + } + return false; + } + if ("Proxy-Authorization".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 17179869184L) == 0) + { + _bits |= 17179869184L; + _headers._ProxyAuthorization = value; + return true; + } + return false; + } + } + break; + case 2: + { + if ("TE".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 137438953472L) == 0) + { + _bits |= 137438953472L; + _headers._TE = value; + return true; + } + return false; + } + } + break; + case 9: + { + if ("Translate".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 274877906944L) == 0) + { + _bits |= 274877906944L; + _headers._Translate = value; + return true; + } + return false; + } + } + break; + case 29: + { + if ("Access-Control-Request-Method".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2199023255552L) == 0) + { + _bits |= 2199023255552L; + _headers._AccessControlRequestMethod = value; + return true; + } + return false; + } + } + break; + case 30: + { + if ("Access-Control-Request-Headers".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4398046511104L) == 0) + { + _bits |= 4398046511104L; + _headers._AccessControlRequestHeaders = value; + return true; + } + return false; + } + } + break; + } + + Unknown.Add(key, value); + // Return true, above will throw and exit for false + return true; + } + + protected override bool RemoveFast(string key) + { + switch (key.Length) + { + case 13: + { + if ("Cache-Control".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1L) != 0) + { + _bits &= ~1L; + _headers._CacheControl = default(StringValues); + return true; + } + return false; + } + if ("Content-Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 65536L) != 0) + { + _bits &= ~65536L; + _headers._ContentRange = default(StringValues); + return true; + } + return false; + } + if ("Last-Modified".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 262144L) != 0) + { + _bits &= ~262144L; + _headers._LastModified = default(StringValues); + return true; + } + return false; + } + if ("Authorization".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8388608L) != 0) + { + _bits &= ~8388608L; + _headers._Authorization = default(StringValues); + return true; + } + return false; + } + if ("If-None-Match".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1073741824L) != 0) + { + _bits &= ~1073741824L; + _headers._IfNoneMatch = default(StringValues); + return true; + } + return false; + } + } + break; + case 10: + { + if ("Connection".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2L) != 0) + { + _bits &= ~2L; + _headers._Connection = default(StringValues); + return true; + } + return false; + } + if ("Keep-Alive".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8L) != 0) + { + _bits &= ~8L; + _headers._KeepAlive = default(StringValues); + return true; + } + return false; + } + if ("User-Agent".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 549755813888L) != 0) + { + _bits &= ~549755813888L; + _headers._UserAgent = default(StringValues); + return true; + } + return false; + } + } + break; + case 4: + { + if ("Date".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4L) != 0) + { + _bits &= ~4L; + _headers._Date = default(StringValues); + return true; + } + return false; + } + if ("From".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 67108864L) != 0) + { + _bits &= ~67108864L; + _headers._From = default(StringValues); + return true; + } + return false; + } + if ("Host".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 134217728L) != 0) + { + _bits &= ~134217728L; + _headers._Host = default(StringValues); + return true; + } + return false; + } + } + break; + case 6: + { + if ("Pragma".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16L) != 0) + { + _bits &= ~16L; + _headers._Pragma = default(StringValues); + return true; + } + return false; + } + if ("Accept".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 524288L) != 0) + { + _bits &= ~524288L; + _headers._Accept = default(StringValues); + return true; + } + return false; + } + if ("Cookie".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16777216L) != 0) + { + _bits &= ~16777216L; + _headers._Cookie = default(StringValues); + return true; + } + return false; + } + if ("Expect".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 33554432L) != 0) + { + _bits &= ~33554432L; + _headers._Expect = default(StringValues); + return true; + } + return false; + } + if ("Origin".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1099511627776L) != 0) + { + _bits &= ~1099511627776L; + _headers._Origin = default(StringValues); + return true; + } + return false; + } + } + break; + case 7: + { + if ("Trailer".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 32L) != 0) + { + _bits &= ~32L; + _headers._Trailer = default(StringValues); + return true; + } + return false; + } + if ("Upgrade".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 128L) != 0) + { + _bits &= ~128L; + _headers._Upgrade = default(StringValues); + return true; + } + return false; + } + if ("Warning".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 512L) != 0) + { + _bits &= ~512L; + _headers._Warning = default(StringValues); + return true; + } + return false; + } + if ("Expires".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 131072L) != 0) + { + _bits &= ~131072L; + _headers._Expires = default(StringValues); + return true; + } + return false; + } + if ("Referer".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 34359738368L) != 0) + { + _bits &= ~34359738368L; + _headers._Referer = default(StringValues); + return true; + } + return false; + } + } + break; + case 17: + { + if ("Transfer-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 64L) != 0) + { + _bits &= ~64L; + _headers._TransferEncoding = default(StringValues); + return true; + } + return false; + } + if ("If-Modified-Since".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 536870912L) != 0) + { + _bits &= ~536870912L; + _headers._IfModifiedSince = default(StringValues); + return true; + } + return false; + } + } + break; + case 3: + { + if ("Via".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 256L) != 0) + { + _bits &= ~256L; + _headers._Via = default(StringValues); + return true; + } + return false; + } + } + break; + case 5: + { + if ("Allow".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1024L) != 0) + { + _bits &= ~1024L; + _headers._Allow = default(StringValues); + return true; + } + return false; + } + if ("Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 68719476736L) != 0) + { + _bits &= ~68719476736L; + _headers._Range = default(StringValues); + return true; + } + return false; + } + } + break; + case 12: + { + if ("Content-Type".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2048L) != 0) + { + _bits &= ~2048L; + _headers._ContentType = default(StringValues); + return true; + } + return false; + } + if ("Max-Forwards".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8589934592L) != 0) + { + _bits &= ~8589934592L; + _headers._MaxForwards = default(StringValues); + return true; + } + return false; + } + } + break; + case 16: + { + if ("Content-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4096L) != 0) + { + _bits &= ~4096L; + _headers._ContentEncoding = default(StringValues); + return true; + } + return false; + } + if ("Content-Language".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8192L) != 0) + { + _bits &= ~8192L; + _headers._ContentLanguage = default(StringValues); + return true; + } + return false; + } + if ("Content-Location".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16384L) != 0) + { + _bits &= ~16384L; + _headers._ContentLocation = default(StringValues); + return true; + } + return false; + } + } + break; + case 11: + { + if ("Content-MD5".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 32768L) != 0) + { + _bits &= ~32768L; + _headers._ContentMD5 = default(StringValues); + return true; + } + return false; + } + } + break; + case 14: + { + if ("Accept-Charset".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1048576L) != 0) + { + _bits &= ~1048576L; + _headers._AcceptCharset = default(StringValues); + return true; + } + return false; + } + if ("Content-Length".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if (_contentLength.HasValue) + { + _contentLength = null; + return true; + } + return false; + } + } + break; + case 15: + { + if ("Accept-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2097152L) != 0) + { + _bits &= ~2097152L; + _headers._AcceptEncoding = default(StringValues); + return true; + } + return false; + } + if ("Accept-Language".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4194304L) != 0) + { + _bits &= ~4194304L; + _headers._AcceptLanguage = default(StringValues); + return true; + } + return false; + } + } + break; + case 8: + { + if ("If-Match".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 268435456L) != 0) + { + _bits &= ~268435456L; + _headers._IfMatch = default(StringValues); + return true; + } + return false; + } + if ("If-Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2147483648L) != 0) + { + _bits &= ~2147483648L; + _headers._IfRange = default(StringValues); + return true; + } + return false; + } + } + break; + case 19: + { + if ("If-Unmodified-Since".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4294967296L) != 0) + { + _bits &= ~4294967296L; + _headers._IfUnmodifiedSince = default(StringValues); + return true; + } + return false; + } + if ("Proxy-Authorization".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 17179869184L) != 0) + { + _bits &= ~17179869184L; + _headers._ProxyAuthorization = default(StringValues); + return true; + } + return false; + } + } + break; + case 2: + { + if ("TE".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 137438953472L) != 0) + { + _bits &= ~137438953472L; + _headers._TE = default(StringValues); + return true; + } + return false; + } + } + break; + case 9: + { + if ("Translate".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 274877906944L) != 0) + { + _bits &= ~274877906944L; + _headers._Translate = default(StringValues); + return true; + } + return false; + } + } + break; + case 29: + { + if ("Access-Control-Request-Method".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2199023255552L) != 0) + { + _bits &= ~2199023255552L; + _headers._AccessControlRequestMethod = default(StringValues); + return true; + } + return false; + } + } + break; + case 30: + { + if ("Access-Control-Request-Headers".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4398046511104L) != 0) + { + _bits &= ~4398046511104L; + _headers._AccessControlRequestHeaders = default(StringValues); + return true; + } + return false; + } + } + break; + } + + return MaybeUnknown?.Remove(key) ?? false; + } + + protected override void ClearFast() + { + MaybeUnknown?.Clear(); + _contentLength = null; + var tempBits = _bits; + _bits = 0; + if(HttpHeaders.BitCount(tempBits) > 12) + { + _headers = default(HeaderReferences); + return; + } + + if ((tempBits & 2L) != 0) + { + _headers._Connection = default(StringValues); + if((tempBits & ~2L) == 0) + { + return; + } + tempBits &= ~2L; + } + + if ((tempBits & 524288L) != 0) + { + _headers._Accept = default(StringValues); + if((tempBits & ~524288L) == 0) + { + return; + } + tempBits &= ~524288L; + } + + if ((tempBits & 134217728L) != 0) + { + _headers._Host = default(StringValues); + if((tempBits & ~134217728L) == 0) + { + return; + } + tempBits &= ~134217728L; + } + + if ((tempBits & 549755813888L) != 0) + { + _headers._UserAgent = default(StringValues); + if((tempBits & ~549755813888L) == 0) + { + return; + } + tempBits &= ~549755813888L; + } + + if ((tempBits & 1L) != 0) + { + _headers._CacheControl = default(StringValues); + if((tempBits & ~1L) == 0) + { + return; + } + tempBits &= ~1L; + } + + if ((tempBits & 4L) != 0) + { + _headers._Date = default(StringValues); + if((tempBits & ~4L) == 0) + { + return; + } + tempBits &= ~4L; + } + + if ((tempBits & 8L) != 0) + { + _headers._KeepAlive = default(StringValues); + if((tempBits & ~8L) == 0) + { + return; + } + tempBits &= ~8L; + } + + if ((tempBits & 16L) != 0) + { + _headers._Pragma = default(StringValues); + if((tempBits & ~16L) == 0) + { + return; + } + tempBits &= ~16L; + } + + if ((tempBits & 32L) != 0) + { + _headers._Trailer = default(StringValues); + if((tempBits & ~32L) == 0) + { + return; + } + tempBits &= ~32L; + } + + if ((tempBits & 64L) != 0) + { + _headers._TransferEncoding = default(StringValues); + if((tempBits & ~64L) == 0) + { + return; + } + tempBits &= ~64L; + } + + if ((tempBits & 128L) != 0) + { + _headers._Upgrade = default(StringValues); + if((tempBits & ~128L) == 0) + { + return; + } + tempBits &= ~128L; + } + + if ((tempBits & 256L) != 0) + { + _headers._Via = default(StringValues); + if((tempBits & ~256L) == 0) + { + return; + } + tempBits &= ~256L; + } + + if ((tempBits & 512L) != 0) + { + _headers._Warning = default(StringValues); + if((tempBits & ~512L) == 0) + { + return; + } + tempBits &= ~512L; + } + + if ((tempBits & 1024L) != 0) + { + _headers._Allow = default(StringValues); + if((tempBits & ~1024L) == 0) + { + return; + } + tempBits &= ~1024L; + } + + if ((tempBits & 2048L) != 0) + { + _headers._ContentType = default(StringValues); + if((tempBits & ~2048L) == 0) + { + return; + } + tempBits &= ~2048L; + } + + if ((tempBits & 4096L) != 0) + { + _headers._ContentEncoding = default(StringValues); + if((tempBits & ~4096L) == 0) + { + return; + } + tempBits &= ~4096L; + } + + if ((tempBits & 8192L) != 0) + { + _headers._ContentLanguage = default(StringValues); + if((tempBits & ~8192L) == 0) + { + return; + } + tempBits &= ~8192L; + } + + if ((tempBits & 16384L) != 0) + { + _headers._ContentLocation = default(StringValues); + if((tempBits & ~16384L) == 0) + { + return; + } + tempBits &= ~16384L; + } + + if ((tempBits & 32768L) != 0) + { + _headers._ContentMD5 = default(StringValues); + if((tempBits & ~32768L) == 0) + { + return; + } + tempBits &= ~32768L; + } + + if ((tempBits & 65536L) != 0) + { + _headers._ContentRange = default(StringValues); + if((tempBits & ~65536L) == 0) + { + return; + } + tempBits &= ~65536L; + } + + if ((tempBits & 131072L) != 0) + { + _headers._Expires = default(StringValues); + if((tempBits & ~131072L) == 0) + { + return; + } + tempBits &= ~131072L; + } + + if ((tempBits & 262144L) != 0) + { + _headers._LastModified = default(StringValues); + if((tempBits & ~262144L) == 0) + { + return; + } + tempBits &= ~262144L; + } + + if ((tempBits & 1048576L) != 0) + { + _headers._AcceptCharset = default(StringValues); + if((tempBits & ~1048576L) == 0) + { + return; + } + tempBits &= ~1048576L; + } + + if ((tempBits & 2097152L) != 0) + { + _headers._AcceptEncoding = default(StringValues); + if((tempBits & ~2097152L) == 0) + { + return; + } + tempBits &= ~2097152L; + } + + if ((tempBits & 4194304L) != 0) + { + _headers._AcceptLanguage = default(StringValues); + if((tempBits & ~4194304L) == 0) + { + return; + } + tempBits &= ~4194304L; + } + + if ((tempBits & 8388608L) != 0) + { + _headers._Authorization = default(StringValues); + if((tempBits & ~8388608L) == 0) + { + return; + } + tempBits &= ~8388608L; + } + + if ((tempBits & 16777216L) != 0) + { + _headers._Cookie = default(StringValues); + if((tempBits & ~16777216L) == 0) + { + return; + } + tempBits &= ~16777216L; + } + + if ((tempBits & 33554432L) != 0) + { + _headers._Expect = default(StringValues); + if((tempBits & ~33554432L) == 0) + { + return; + } + tempBits &= ~33554432L; + } + + if ((tempBits & 67108864L) != 0) + { + _headers._From = default(StringValues); + if((tempBits & ~67108864L) == 0) + { + return; + } + tempBits &= ~67108864L; + } + + if ((tempBits & 268435456L) != 0) + { + _headers._IfMatch = default(StringValues); + if((tempBits & ~268435456L) == 0) + { + return; + } + tempBits &= ~268435456L; + } + + if ((tempBits & 536870912L) != 0) + { + _headers._IfModifiedSince = default(StringValues); + if((tempBits & ~536870912L) == 0) + { + return; + } + tempBits &= ~536870912L; + } + + if ((tempBits & 1073741824L) != 0) + { + _headers._IfNoneMatch = default(StringValues); + if((tempBits & ~1073741824L) == 0) + { + return; + } + tempBits &= ~1073741824L; + } + + if ((tempBits & 2147483648L) != 0) + { + _headers._IfRange = default(StringValues); + if((tempBits & ~2147483648L) == 0) + { + return; + } + tempBits &= ~2147483648L; + } + + if ((tempBits & 4294967296L) != 0) + { + _headers._IfUnmodifiedSince = default(StringValues); + if((tempBits & ~4294967296L) == 0) + { + return; + } + tempBits &= ~4294967296L; + } + + if ((tempBits & 8589934592L) != 0) + { + _headers._MaxForwards = default(StringValues); + if((tempBits & ~8589934592L) == 0) + { + return; + } + tempBits &= ~8589934592L; + } + + if ((tempBits & 17179869184L) != 0) + { + _headers._ProxyAuthorization = default(StringValues); + if((tempBits & ~17179869184L) == 0) + { + return; + } + tempBits &= ~17179869184L; + } + + if ((tempBits & 34359738368L) != 0) + { + _headers._Referer = default(StringValues); + if((tempBits & ~34359738368L) == 0) + { + return; + } + tempBits &= ~34359738368L; + } + + if ((tempBits & 68719476736L) != 0) + { + _headers._Range = default(StringValues); + if((tempBits & ~68719476736L) == 0) + { + return; + } + tempBits &= ~68719476736L; + } + + if ((tempBits & 137438953472L) != 0) + { + _headers._TE = default(StringValues); + if((tempBits & ~137438953472L) == 0) + { + return; + } + tempBits &= ~137438953472L; + } + + if ((tempBits & 274877906944L) != 0) + { + _headers._Translate = default(StringValues); + if((tempBits & ~274877906944L) == 0) + { + return; + } + tempBits &= ~274877906944L; + } + + if ((tempBits & 1099511627776L) != 0) + { + _headers._Origin = default(StringValues); + if((tempBits & ~1099511627776L) == 0) + { + return; + } + tempBits &= ~1099511627776L; + } + + if ((tempBits & 2199023255552L) != 0) + { + _headers._AccessControlRequestMethod = default(StringValues); + if((tempBits & ~2199023255552L) == 0) + { + return; + } + tempBits &= ~2199023255552L; + } + + if ((tempBits & 4398046511104L) != 0) + { + _headers._AccessControlRequestHeaders = default(StringValues); + if((tempBits & ~4398046511104L) == 0) + { + return; + } + tempBits &= ~4398046511104L; + } + + } + + protected override bool CopyToFast(KeyValuePair[] array, int arrayIndex) + { + if (arrayIndex < 0) + { + return false; + } + + if ((_bits & 1L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Cache-Control", _headers._CacheControl); + ++arrayIndex; + } + if ((_bits & 2L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Connection", _headers._Connection); + ++arrayIndex; + } + if ((_bits & 4L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Date", _headers._Date); + ++arrayIndex; + } + if ((_bits & 8L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Keep-Alive", _headers._KeepAlive); + ++arrayIndex; + } + if ((_bits & 16L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Pragma", _headers._Pragma); + ++arrayIndex; + } + if ((_bits & 32L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Trailer", _headers._Trailer); + ++arrayIndex; + } + if ((_bits & 64L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Transfer-Encoding", _headers._TransferEncoding); + ++arrayIndex; + } + if ((_bits & 128L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Upgrade", _headers._Upgrade); + ++arrayIndex; + } + if ((_bits & 256L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Via", _headers._Via); + ++arrayIndex; + } + if ((_bits & 512L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Warning", _headers._Warning); + ++arrayIndex; + } + if ((_bits & 1024L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Allow", _headers._Allow); + ++arrayIndex; + } + if ((_bits & 2048L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-Type", _headers._ContentType); + ++arrayIndex; + } + if ((_bits & 4096L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-Encoding", _headers._ContentEncoding); + ++arrayIndex; + } + if ((_bits & 8192L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-Language", _headers._ContentLanguage); + ++arrayIndex; + } + if ((_bits & 16384L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-Location", _headers._ContentLocation); + ++arrayIndex; + } + if ((_bits & 32768L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-MD5", _headers._ContentMD5); + ++arrayIndex; + } + if ((_bits & 65536L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-Range", _headers._ContentRange); + ++arrayIndex; + } + if ((_bits & 131072L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Expires", _headers._Expires); + ++arrayIndex; + } + if ((_bits & 262144L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Last-Modified", _headers._LastModified); + ++arrayIndex; + } + if ((_bits & 524288L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Accept", _headers._Accept); + ++arrayIndex; + } + if ((_bits & 1048576L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Accept-Charset", _headers._AcceptCharset); + ++arrayIndex; + } + if ((_bits & 2097152L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Accept-Encoding", _headers._AcceptEncoding); + ++arrayIndex; + } + if ((_bits & 4194304L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Accept-Language", _headers._AcceptLanguage); + ++arrayIndex; + } + if ((_bits & 8388608L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Authorization", _headers._Authorization); + ++arrayIndex; + } + if ((_bits & 16777216L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Cookie", _headers._Cookie); + ++arrayIndex; + } + if ((_bits & 33554432L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Expect", _headers._Expect); + ++arrayIndex; + } + if ((_bits & 67108864L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("From", _headers._From); + ++arrayIndex; + } + if ((_bits & 134217728L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Host", _headers._Host); + ++arrayIndex; + } + if ((_bits & 268435456L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("If-Match", _headers._IfMatch); + ++arrayIndex; + } + if ((_bits & 536870912L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("If-Modified-Since", _headers._IfModifiedSince); + ++arrayIndex; + } + if ((_bits & 1073741824L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("If-None-Match", _headers._IfNoneMatch); + ++arrayIndex; + } + if ((_bits & 2147483648L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("If-Range", _headers._IfRange); + ++arrayIndex; + } + if ((_bits & 4294967296L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("If-Unmodified-Since", _headers._IfUnmodifiedSince); + ++arrayIndex; + } + if ((_bits & 8589934592L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Max-Forwards", _headers._MaxForwards); + ++arrayIndex; + } + if ((_bits & 17179869184L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Proxy-Authorization", _headers._ProxyAuthorization); + ++arrayIndex; + } + if ((_bits & 34359738368L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Referer", _headers._Referer); + ++arrayIndex; + } + if ((_bits & 68719476736L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Range", _headers._Range); + ++arrayIndex; + } + if ((_bits & 137438953472L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("TE", _headers._TE); + ++arrayIndex; + } + if ((_bits & 274877906944L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Translate", _headers._Translate); + ++arrayIndex; + } + if ((_bits & 549755813888L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("User-Agent", _headers._UserAgent); + ++arrayIndex; + } + if ((_bits & 1099511627776L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Origin", _headers._Origin); + ++arrayIndex; + } + if ((_bits & 2199023255552L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Access-Control-Request-Method", _headers._AccessControlRequestMethod); + ++arrayIndex; + } + if ((_bits & 4398046511104L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Access-Control-Request-Headers", _headers._AccessControlRequestHeaders); + ++arrayIndex; + } + if (_contentLength.HasValue) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-Length", HeaderUtilities.FormatNonNegativeInt64(_contentLength.Value)); + ++arrayIndex; + } + ((ICollection>)MaybeUnknown)?.CopyTo(array, arrayIndex); + + return true; + } + + + public unsafe void Append(byte* pKeyBytes, int keyLength, string value) + { + var pUB = pKeyBytes; + var pUL = (ulong*)pUB; + var pUI = (uint*)pUB; + var pUS = (ushort*)pUB; + var stringValue = new StringValues(value); + switch (keyLength) + { + case 10: + { + if ((((pUL[0] & 16131858542891098079uL) == 5283922227757993795uL) && ((pUS[4] & 57311u) == 20047u))) + { + if ((_bits & 2L) != 0) + { + _headers._Connection = AppendValue(_headers._Connection, value); + } + else + { + _bits |= 2L; + _headers._Connection = stringValue; + } + return; + } + + if ((((pUL[0] & 16131858680330051551uL) == 4992030374873092949uL) && ((pUS[4] & 57311u) == 21582u))) + { + if ((_bits & 549755813888L) != 0) + { + _headers._UserAgent = AppendValue(_headers._UserAgent, value); + } + else + { + _bits |= 549755813888L; + _headers._UserAgent = stringValue; + } + return; + } + } + break; + + case 6: + { + if ((((pUI[0] & 3755991007u) == 1162036033u) && ((pUS[2] & 57311u) == 21584u))) + { + if ((_bits & 524288L) != 0) + { + _headers._Accept = AppendValue(_headers._Accept, value); + } + else + { + _bits |= 524288L; + _headers._Accept = stringValue; + } + return; + } + } + break; + + case 4: + { + if ((((pUI[0] & 3755991007u) == 1414745928u))) + { + if ((_bits & 134217728L) != 0) + { + _headers._Host = AppendValue(_headers._Host, value); + } + else + { + _bits |= 134217728L; + _headers._Host = stringValue; + } + return; + } + } + break; + } + + AppendNonPrimaryHeaders(pKeyBytes, keyLength, value); + } + + private unsafe void AppendNonPrimaryHeaders(byte* pKeyBytes, int keyLength, string value) + { + var pUB = pKeyBytes; + var pUL = (ulong*)pUB; + var pUI = (uint*)pUB; + var pUS = (ushort*)pUB; + var stringValue = new StringValues(value); + switch (keyLength) + { + case 13: + { + if ((((pUL[0] & 16131893727263186911uL) == 5711458528024281411uL) && ((pUI[2] & 3755991007u) == 1330795598u) && ((pUB[12] & 223u) == 76u))) + { + if ((_bits & 1L) != 0) + { + _headers._CacheControl = AppendValue(_headers._CacheControl, value); + } + else + { + _bits |= 1L; + _headers._CacheControl = stringValue; + } + return; + } + + if ((((pUL[0] & 18437701552104792031uL) == 3266321689424580419uL) && ((pUI[2] & 3755991007u) == 1196310866u) && ((pUB[12] & 223u) == 69u))) + { + if ((_bits & 65536L) != 0) + { + _headers._ContentRange = AppendValue(_headers._ContentRange, value); + } + else + { + _bits |= 65536L; + _headers._ContentRange = stringValue; + } + return; + } + + if ((((pUL[0] & 16131858680330051551uL) == 4922237774822850892uL) && ((pUI[2] & 3755991007u) == 1162430025u) && ((pUB[12] & 223u) == 68u))) + { + if ((_bits & 262144L) != 0) + { + _headers._LastModified = AppendValue(_headers._LastModified, value); + } + else + { + _bits |= 262144L; + _headers._LastModified = stringValue; + } + return; + } + + if ((((pUL[0] & 16131858542891098079uL) == 6505821637182772545uL) && ((pUI[2] & 3755991007u) == 1330205761u) && ((pUB[12] & 223u) == 78u))) + { + if ((_bits & 8388608L) != 0) + { + _headers._Authorization = AppendValue(_headers._Authorization, value); + } + else + { + _bits |= 8388608L; + _headers._Authorization = stringValue; + } + return; + } + + if ((((pUL[0] & 18437701552106889183uL) == 3262099607620765257uL) && ((pUI[2] & 3755991007u) == 1129595213u) && ((pUB[12] & 223u) == 72u))) + { + if ((_bits & 1073741824L) != 0) + { + _headers._IfNoneMatch = AppendValue(_headers._IfNoneMatch, value); + } + else + { + _bits |= 1073741824L; + _headers._IfNoneMatch = stringValue; + } + return; + } + } + break; + + case 4: + { + if ((((pUI[0] & 3755991007u) == 1163149636u))) + { + if ((_bits & 4L) != 0) + { + _headers._Date = AppendValue(_headers._Date, value); + } + else + { + _bits |= 4L; + _headers._Date = stringValue; + } + return; + } + + if ((((pUI[0] & 3755991007u) == 1297044038u))) + { + if ((_bits & 67108864L) != 0) + { + _headers._From = AppendValue(_headers._From, value); + } + else + { + _bits |= 67108864L; + _headers._From = stringValue; + } + return; + } + } + break; + + case 10: + { + if ((((pUL[0] & 16131858680330051551uL) == 5281668125874799947uL) && ((pUS[4] & 57311u) == 17750u))) + { + if ((_bits & 8L) != 0) + { + _headers._KeepAlive = AppendValue(_headers._KeepAlive, value); + } + else + { + _bits |= 8L; + _headers._KeepAlive = stringValue; + } + return; + } + } + break; + + case 6: + { + if ((((pUI[0] & 3755991007u) == 1195463248u) && ((pUS[2] & 57311u) == 16717u))) + { + if ((_bits & 16L) != 0) + { + _headers._Pragma = AppendValue(_headers._Pragma, value); + } + else + { + _bits |= 16L; + _headers._Pragma = stringValue; + } + return; + } + + if ((((pUI[0] & 3755991007u) == 1263488835u) && ((pUS[2] & 57311u) == 17737u))) + { + if ((_bits & 16777216L) != 0) + { + _headers._Cookie = AppendValue(_headers._Cookie, value); + } + else + { + _bits |= 16777216L; + _headers._Cookie = stringValue; + } + return; + } + + if ((((pUI[0] & 3755991007u) == 1162893381u) && ((pUS[2] & 57311u) == 21571u))) + { + if ((_bits & 33554432L) != 0) + { + _headers._Expect = AppendValue(_headers._Expect, value); + } + else + { + _bits |= 33554432L; + _headers._Expect = stringValue; + } + return; + } + + if ((((pUI[0] & 3755991007u) == 1195987535u) && ((pUS[2] & 57311u) == 20041u))) + { + if ((_bits & 1099511627776L) != 0) + { + _headers._Origin = AppendValue(_headers._Origin, value); + } + else + { + _bits |= 1099511627776L; + _headers._Origin = stringValue; + } + return; + } + } + break; + + case 7: + { + if ((((pUI[0] & 3755991007u) == 1229017684u) && ((pUS[2] & 57311u) == 17740u) && ((pUB[6] & 223u) == 82u))) + { + if ((_bits & 32L) != 0) + { + _headers._Trailer = AppendValue(_headers._Trailer, value); + } + else + { + _bits |= 32L; + _headers._Trailer = stringValue; + } + return; + } + + if ((((pUI[0] & 3755991007u) == 1380405333u) && ((pUS[2] & 57311u) == 17473u) && ((pUB[6] & 223u) == 69u))) + { + if ((_bits & 128L) != 0) + { + _headers._Upgrade = AppendValue(_headers._Upgrade, value); + } + else + { + _bits |= 128L; + _headers._Upgrade = stringValue; + } + return; + } + + if ((((pUI[0] & 3755991007u) == 1314013527u) && ((pUS[2] & 57311u) == 20041u) && ((pUB[6] & 223u) == 71u))) + { + if ((_bits & 512L) != 0) + { + _headers._Warning = AppendValue(_headers._Warning, value); + } + else + { + _bits |= 512L; + _headers._Warning = stringValue; + } + return; + } + + if ((((pUI[0] & 3755991007u) == 1230002245u) && ((pUS[2] & 57311u) == 17746u) && ((pUB[6] & 223u) == 83u))) + { + if ((_bits & 131072L) != 0) + { + _headers._Expires = AppendValue(_headers._Expires, value); + } + else + { + _bits |= 131072L; + _headers._Expires = stringValue; + } + return; + } + + if ((((pUI[0] & 3755991007u) == 1162233170u) && ((pUS[2] & 57311u) == 17746u) && ((pUB[6] & 223u) == 82u))) + { + if ((_bits & 34359738368L) != 0) + { + _headers._Referer = AppendValue(_headers._Referer, value); + } + else + { + _bits |= 34359738368L; + _headers._Referer = stringValue; + } + return; + } + } + break; + + case 17: + { + if ((((pUL[0] & 16131858542891098079uL) == 5928221808112259668uL) && ((pUL[1] & 16131858542891098111uL) == 5641115115480565037uL) && ((pUB[16] & 223u) == 71u))) + { + if ((_bits & 64L) != 0) + { + _headers._TransferEncoding = AppendValue(_headers._TransferEncoding, value); + } + else + { + _bits |= 64L; + _headers._TransferEncoding = stringValue; + } + return; + } + + if ((((pUL[0] & 16131858542893195231uL) == 5064654363342751305uL) && ((pUL[1] & 16131858543427968991uL) == 4849894470315165001uL) && ((pUB[16] & 223u) == 69u))) + { + if ((_bits & 536870912L) != 0) + { + _headers._IfModifiedSince = AppendValue(_headers._IfModifiedSince, value); + } + else + { + _bits |= 536870912L; + _headers._IfModifiedSince = stringValue; + } + return; + } + } + break; + + case 3: + { + if ((((pUS[0] & 57311u) == 18774u) && ((pUB[2] & 223u) == 65u))) + { + if ((_bits & 256L) != 0) + { + _headers._Via = AppendValue(_headers._Via, value); + } + else + { + _bits |= 256L; + _headers._Via = stringValue; + } + return; + } + } + break; + + case 5: + { + if ((((pUI[0] & 3755991007u) == 1330400321u) && ((pUB[4] & 223u) == 87u))) + { + if ((_bits & 1024L) != 0) + { + _headers._Allow = AppendValue(_headers._Allow, value); + } + else + { + _bits |= 1024L; + _headers._Allow = stringValue; + } + return; + } + + if ((((pUI[0] & 3755991007u) == 1196310866u) && ((pUB[4] & 223u) == 69u))) + { + if ((_bits & 68719476736L) != 0) + { + _headers._Range = AppendValue(_headers._Range, value); + } + else + { + _bits |= 68719476736L; + _headers._Range = stringValue; + } + return; + } + } + break; + + case 12: + { + if ((((pUL[0] & 18437701552104792031uL) == 3266321689424580419uL) && ((pUI[2] & 3755991007u) == 1162893652u))) + { + if ((_bits & 2048L) != 0) + { + _headers._ContentType = AppendValue(_headers._ContentType, value); + } + else + { + _bits |= 2048L; + _headers._ContentType = stringValue; + } + return; + } + + if ((((pUL[0] & 16131858543427968991uL) == 6292178792217067853uL) && ((pUI[2] & 3755991007u) == 1396986433u))) + { + if ((_bits & 8589934592L) != 0) + { + _headers._MaxForwards = AppendValue(_headers._MaxForwards, value); + } + else + { + _bits |= 8589934592L; + _headers._MaxForwards = stringValue; + } + return; + } + } + break; + + case 16: + { + if ((((pUL[0] & 18437701552104792031uL) == 3266321689424580419uL) && ((pUL[1] & 16131858542891098079uL) == 5138124782612729413uL))) + { + if ((_bits & 4096L) != 0) + { + _headers._ContentEncoding = AppendValue(_headers._ContentEncoding, value); + } + else + { + _bits |= 4096L; + _headers._ContentEncoding = stringValue; + } + return; + } + + if ((((pUL[0] & 18437701552104792031uL) == 3266321689424580419uL) && ((pUL[1] & 16131858542891098079uL) == 4992030546487820620uL))) + { + if ((_bits & 8192L) != 0) + { + _headers._ContentLanguage = AppendValue(_headers._ContentLanguage, value); + } + else + { + _bits |= 8192L; + _headers._ContentLanguage = stringValue; + } + return; + } + + if ((((pUL[0] & 18437701552104792031uL) == 3266321689424580419uL) && ((pUL[1] & 16131858542891098079uL) == 5642809484339531596uL))) + { + if ((_bits & 16384L) != 0) + { + _headers._ContentLocation = AppendValue(_headers._ContentLocation, value); + } + else + { + _bits |= 16384L; + _headers._ContentLocation = stringValue; + } + return; + } + } + break; + + case 11: + { + if ((((pUL[0] & 18437701552104792031uL) == 3266321689424580419uL) && ((pUS[4] & 57311u) == 17485u) && ((pUB[10] & 255u) == 53u))) + { + if ((_bits & 32768L) != 0) + { + _headers._ContentMD5 = AppendValue(_headers._ContentMD5, value); + } + else + { + _bits |= 32768L; + _headers._ContentMD5 = stringValue; + } + return; + } + } + break; + + case 14: + { + if ((((pUL[0] & 16140865742145839071uL) == 4840617878229304129uL) && ((pUI[2] & 3755991007u) == 1397899592u) && ((pUS[6] & 57311u) == 21573u))) + { + if ((_bits & 1048576L) != 0) + { + _headers._AcceptCharset = AppendValue(_headers._AcceptCharset, value); + } + else + { + _bits |= 1048576L; + _headers._AcceptCharset = stringValue; + } + return; + } + + if ((((pUL[0] & 18437701552104792031uL) == 3266321689424580419uL) && ((pUI[2] & 3755991007u) == 1196311884u) && ((pUS[6] & 57311u) == 18516u))) + { + if (_contentLength.HasValue) + { + BadHttpRequestException.Throw(RequestRejectionReason.MultipleContentLengths); + } + else + { + _contentLength = ParseContentLength(value); + } + return; + } + } + break; + + case 15: + { + if ((((pUL[0] & 16140865742145839071uL) == 4984733066305160001uL) && ((pUI[2] & 3755991007u) == 1146045262u) && ((pUS[6] & 57311u) == 20041u) && ((pUB[14] & 223u) == 71u))) + { + if ((_bits & 2097152L) != 0) + { + _headers._AcceptEncoding = AppendValue(_headers._AcceptEncoding, value); + } + else + { + _bits |= 2097152L; + _headers._AcceptEncoding = stringValue; + } + return; + } + + if ((((pUL[0] & 16140865742145839071uL) == 5489136224570655553uL) && ((pUI[2] & 3755991007u) == 1430736449u) && ((pUS[6] & 57311u) == 18241u) && ((pUB[14] & 223u) == 69u))) + { + if ((_bits & 4194304L) != 0) + { + _headers._AcceptLanguage = AppendValue(_headers._AcceptLanguage, value); + } + else + { + _bits |= 4194304L; + _headers._AcceptLanguage = stringValue; + } + return; + } + } + break; + + case 8: + { + if ((((pUL[0] & 16131858542893195231uL) == 5207098233614845513uL))) + { + if ((_bits & 268435456L) != 0) + { + _headers._IfMatch = AppendValue(_headers._IfMatch, value); + } + else + { + _bits |= 268435456L; + _headers._IfMatch = stringValue; + } + return; + } + + if ((((pUL[0] & 16131858542893195231uL) == 4992044754422023753uL))) + { + if ((_bits & 2147483648L) != 0) + { + _headers._IfRange = AppendValue(_headers._IfRange, value); + } + else + { + _bits |= 2147483648L; + _headers._IfRange = stringValue; + } + return; + } + } + break; + + case 19: + { + if ((((pUL[0] & 16131858542893195231uL) == 4922237916571059785uL) && ((pUL[1] & 16131893727263186911uL) == 5283616559079179849uL) && ((pUS[8] & 57311u) == 17230u) && ((pUB[18] & 223u) == 69u))) + { + if ((_bits & 4294967296L) != 0) + { + _headers._IfUnmodifiedSince = AppendValue(_headers._IfUnmodifiedSince, value); + } + else + { + _bits |= 4294967296L; + _headers._IfUnmodifiedSince = stringValue; + } + return; + } + + if ((((pUL[0] & 16131893727263186911uL) == 6143241228466999888uL) && ((pUL[1] & 16131858542891098079uL) == 6071233043632179284uL) && ((pUS[8] & 57311u) == 20297u) && ((pUB[18] & 223u) == 78u))) + { + if ((_bits & 17179869184L) != 0) + { + _headers._ProxyAuthorization = AppendValue(_headers._ProxyAuthorization, value); + } + else + { + _bits |= 17179869184L; + _headers._ProxyAuthorization = stringValue; + } + return; + } + } + break; + + case 2: + { + if ((((pUS[0] & 57311u) == 17748u))) + { + if ((_bits & 137438953472L) != 0) + { + _headers._TE = AppendValue(_headers._TE, value); + } + else + { + _bits |= 137438953472L; + _headers._TE = stringValue; + } + return; + } + } + break; + + case 9: + { + if ((((pUL[0] & 16131858542891098079uL) == 6071217693351039572uL) && ((pUB[8] & 223u) == 69u))) + { + if ((_bits & 274877906944L) != 0) + { + _headers._Translate = AppendValue(_headers._Translate, value); + } + else + { + _bits |= 274877906944L; + _headers._Translate = stringValue; + } + return; + } + } + break; + + case 29: + { + if ((((pUL[0] & 16140865742145839071uL) == 4840616791602578241uL) && ((pUL[1] & 16140865742145839071uL) == 5921472988629454415uL) && ((pUL[2] & 16140865742145839071uL) == 5561193831494668613uL) && ((pUI[6] & 3755991007u) == 1330140229u) && ((pUB[28] & 223u) == 68u))) + { + if ((_bits & 2199023255552L) != 0) + { + _headers._AccessControlRequestMethod = AppendValue(_headers._AccessControlRequestMethod, value); + } + else + { + _bits |= 2199023255552L; + _headers._AccessControlRequestMethod = stringValue; + } + return; + } + } + break; + + case 30: + { + if ((((pUL[0] & 16140865742145839071uL) == 4840616791602578241uL) && ((pUL[1] & 16140865742145839071uL) == 5921472988629454415uL) && ((pUL[2] & 16140865742145839071uL) == 5200905861305028933uL) && ((pUI[6] & 3755991007u) == 1162101061u) && ((pUS[14] & 57311u) == 21330u))) + { + if ((_bits & 4398046511104L) != 0) + { + _headers._AccessControlRequestHeaders = AppendValue(_headers._AccessControlRequestHeaders, value); + } + else + { + _bits |= 4398046511104L; + _headers._AccessControlRequestHeaders = stringValue; + } + return; + } + } + break; + } + + AppendUnknownHeaders(pKeyBytes, keyLength, value); + } + + private struct HeaderReferences + { + public StringValues _CacheControl; + public StringValues _Connection; + public StringValues _Date; + public StringValues _KeepAlive; + public StringValues _Pragma; + public StringValues _Trailer; + public StringValues _TransferEncoding; + public StringValues _Upgrade; + public StringValues _Via; + public StringValues _Warning; + public StringValues _Allow; + public StringValues _ContentType; + public StringValues _ContentEncoding; + public StringValues _ContentLanguage; + public StringValues _ContentLocation; + public StringValues _ContentMD5; + public StringValues _ContentRange; + public StringValues _Expires; + public StringValues _LastModified; + public StringValues _Accept; + public StringValues _AcceptCharset; + public StringValues _AcceptEncoding; + public StringValues _AcceptLanguage; + public StringValues _Authorization; + public StringValues _Cookie; + public StringValues _Expect; + public StringValues _From; + public StringValues _Host; + public StringValues _IfMatch; + public StringValues _IfModifiedSince; + public StringValues _IfNoneMatch; + public StringValues _IfRange; + public StringValues _IfUnmodifiedSince; + public StringValues _MaxForwards; + public StringValues _ProxyAuthorization; + public StringValues _Referer; + public StringValues _Range; + public StringValues _TE; + public StringValues _Translate; + public StringValues _UserAgent; + public StringValues _Origin; + public StringValues _AccessControlRequestMethod; + public StringValues _AccessControlRequestHeaders; + + } + + public partial struct Enumerator + { + public bool MoveNext() + { + switch (_state) + { + + case 0: + goto state0; + + case 1: + goto state1; + + case 2: + goto state2; + + case 3: + goto state3; + + case 4: + goto state4; + + case 5: + goto state5; + + case 6: + goto state6; + + case 7: + goto state7; + + case 8: + goto state8; + + case 9: + goto state9; + + case 10: + goto state10; + + case 11: + goto state11; + + case 12: + goto state12; + + case 13: + goto state13; + + case 14: + goto state14; + + case 15: + goto state15; + + case 16: + goto state16; + + case 17: + goto state17; + + case 18: + goto state18; + + case 19: + goto state19; + + case 20: + goto state20; + + case 21: + goto state21; + + case 22: + goto state22; + + case 23: + goto state23; + + case 24: + goto state24; + + case 25: + goto state25; + + case 26: + goto state26; + + case 27: + goto state27; + + case 28: + goto state28; + + case 29: + goto state29; + + case 30: + goto state30; + + case 31: + goto state31; + + case 32: + goto state32; + + case 33: + goto state33; + + case 34: + goto state34; + + case 35: + goto state35; + + case 36: + goto state36; + + case 37: + goto state37; + + case 38: + goto state38; + + case 39: + goto state39; + + case 40: + goto state40; + + case 41: + goto state41; + + case 42: + goto state42; + + case 44: + goto state44; + default: + goto state_default; + } + + state0: + if ((_bits & 1L) != 0) + { + _current = new KeyValuePair("Cache-Control", _collection._headers._CacheControl); + _state = 1; + return true; + } + + state1: + if ((_bits & 2L) != 0) + { + _current = new KeyValuePair("Connection", _collection._headers._Connection); + _state = 2; + return true; + } + + state2: + if ((_bits & 4L) != 0) + { + _current = new KeyValuePair("Date", _collection._headers._Date); + _state = 3; + return true; + } + + state3: + if ((_bits & 8L) != 0) + { + _current = new KeyValuePair("Keep-Alive", _collection._headers._KeepAlive); + _state = 4; + return true; + } + + state4: + if ((_bits & 16L) != 0) + { + _current = new KeyValuePair("Pragma", _collection._headers._Pragma); + _state = 5; + return true; + } + + state5: + if ((_bits & 32L) != 0) + { + _current = new KeyValuePair("Trailer", _collection._headers._Trailer); + _state = 6; + return true; + } + + state6: + if ((_bits & 64L) != 0) + { + _current = new KeyValuePair("Transfer-Encoding", _collection._headers._TransferEncoding); + _state = 7; + return true; + } + + state7: + if ((_bits & 128L) != 0) + { + _current = new KeyValuePair("Upgrade", _collection._headers._Upgrade); + _state = 8; + return true; + } + + state8: + if ((_bits & 256L) != 0) + { + _current = new KeyValuePair("Via", _collection._headers._Via); + _state = 9; + return true; + } + + state9: + if ((_bits & 512L) != 0) + { + _current = new KeyValuePair("Warning", _collection._headers._Warning); + _state = 10; + return true; + } + + state10: + if ((_bits & 1024L) != 0) + { + _current = new KeyValuePair("Allow", _collection._headers._Allow); + _state = 11; + return true; + } + + state11: + if ((_bits & 2048L) != 0) + { + _current = new KeyValuePair("Content-Type", _collection._headers._ContentType); + _state = 12; + return true; + } + + state12: + if ((_bits & 4096L) != 0) + { + _current = new KeyValuePair("Content-Encoding", _collection._headers._ContentEncoding); + _state = 13; + return true; + } + + state13: + if ((_bits & 8192L) != 0) + { + _current = new KeyValuePair("Content-Language", _collection._headers._ContentLanguage); + _state = 14; + return true; + } + + state14: + if ((_bits & 16384L) != 0) + { + _current = new KeyValuePair("Content-Location", _collection._headers._ContentLocation); + _state = 15; + return true; + } + + state15: + if ((_bits & 32768L) != 0) + { + _current = new KeyValuePair("Content-MD5", _collection._headers._ContentMD5); + _state = 16; + return true; + } + + state16: + if ((_bits & 65536L) != 0) + { + _current = new KeyValuePair("Content-Range", _collection._headers._ContentRange); + _state = 17; + return true; + } + + state17: + if ((_bits & 131072L) != 0) + { + _current = new KeyValuePair("Expires", _collection._headers._Expires); + _state = 18; + return true; + } + + state18: + if ((_bits & 262144L) != 0) + { + _current = new KeyValuePair("Last-Modified", _collection._headers._LastModified); + _state = 19; + return true; + } + + state19: + if ((_bits & 524288L) != 0) + { + _current = new KeyValuePair("Accept", _collection._headers._Accept); + _state = 20; + return true; + } + + state20: + if ((_bits & 1048576L) != 0) + { + _current = new KeyValuePair("Accept-Charset", _collection._headers._AcceptCharset); + _state = 21; + return true; + } + + state21: + if ((_bits & 2097152L) != 0) + { + _current = new KeyValuePair("Accept-Encoding", _collection._headers._AcceptEncoding); + _state = 22; + return true; + } + + state22: + if ((_bits & 4194304L) != 0) + { + _current = new KeyValuePair("Accept-Language", _collection._headers._AcceptLanguage); + _state = 23; + return true; + } + + state23: + if ((_bits & 8388608L) != 0) + { + _current = new KeyValuePair("Authorization", _collection._headers._Authorization); + _state = 24; + return true; + } + + state24: + if ((_bits & 16777216L) != 0) + { + _current = new KeyValuePair("Cookie", _collection._headers._Cookie); + _state = 25; + return true; + } + + state25: + if ((_bits & 33554432L) != 0) + { + _current = new KeyValuePair("Expect", _collection._headers._Expect); + _state = 26; + return true; + } + + state26: + if ((_bits & 67108864L) != 0) + { + _current = new KeyValuePair("From", _collection._headers._From); + _state = 27; + return true; + } + + state27: + if ((_bits & 134217728L) != 0) + { + _current = new KeyValuePair("Host", _collection._headers._Host); + _state = 28; + return true; + } + + state28: + if ((_bits & 268435456L) != 0) + { + _current = new KeyValuePair("If-Match", _collection._headers._IfMatch); + _state = 29; + return true; + } + + state29: + if ((_bits & 536870912L) != 0) + { + _current = new KeyValuePair("If-Modified-Since", _collection._headers._IfModifiedSince); + _state = 30; + return true; + } + + state30: + if ((_bits & 1073741824L) != 0) + { + _current = new KeyValuePair("If-None-Match", _collection._headers._IfNoneMatch); + _state = 31; + return true; + } + + state31: + if ((_bits & 2147483648L) != 0) + { + _current = new KeyValuePair("If-Range", _collection._headers._IfRange); + _state = 32; + return true; + } + + state32: + if ((_bits & 4294967296L) != 0) + { + _current = new KeyValuePair("If-Unmodified-Since", _collection._headers._IfUnmodifiedSince); + _state = 33; + return true; + } + + state33: + if ((_bits & 8589934592L) != 0) + { + _current = new KeyValuePair("Max-Forwards", _collection._headers._MaxForwards); + _state = 34; + return true; + } + + state34: + if ((_bits & 17179869184L) != 0) + { + _current = new KeyValuePair("Proxy-Authorization", _collection._headers._ProxyAuthorization); + _state = 35; + return true; + } + + state35: + if ((_bits & 34359738368L) != 0) + { + _current = new KeyValuePair("Referer", _collection._headers._Referer); + _state = 36; + return true; + } + + state36: + if ((_bits & 68719476736L) != 0) + { + _current = new KeyValuePair("Range", _collection._headers._Range); + _state = 37; + return true; + } + + state37: + if ((_bits & 137438953472L) != 0) + { + _current = new KeyValuePair("TE", _collection._headers._TE); + _state = 38; + return true; + } + + state38: + if ((_bits & 274877906944L) != 0) + { + _current = new KeyValuePair("Translate", _collection._headers._Translate); + _state = 39; + return true; + } + + state39: + if ((_bits & 549755813888L) != 0) + { + _current = new KeyValuePair("User-Agent", _collection._headers._UserAgent); + _state = 40; + return true; + } + + state40: + if ((_bits & 1099511627776L) != 0) + { + _current = new KeyValuePair("Origin", _collection._headers._Origin); + _state = 41; + return true; + } + + state41: + if ((_bits & 2199023255552L) != 0) + { + _current = new KeyValuePair("Access-Control-Request-Method", _collection._headers._AccessControlRequestMethod); + _state = 42; + return true; + } + + state42: + if ((_bits & 4398046511104L) != 0) + { + _current = new KeyValuePair("Access-Control-Request-Headers", _collection._headers._AccessControlRequestHeaders); + _state = 43; + return true; + } + + state44: + if (_collection._contentLength.HasValue) + { + _current = new KeyValuePair("Content-Length", HeaderUtilities.FormatNonNegativeInt64(_collection._contentLength.Value)); + _state = 45; + return true; + } + state_default: + if (!_hasUnknown || !_unknownEnumerator.MoveNext()) + { + _current = default(KeyValuePair); + return false; + } + _current = _unknownEnumerator.Current; + return true; + } + } + } + + public partial class HttpResponseHeaders + { + private static byte[] _headerBytes = new byte[] + { + 13,10,67,97,99,104,101,45,67,111,110,116,114,111,108,58,32,13,10,67,111,110,110,101,99,116,105,111,110,58,32,13,10,68,97,116,101,58,32,13,10,75,101,101,112,45,65,108,105,118,101,58,32,13,10,80,114,97,103,109,97,58,32,13,10,84,114,97,105,108,101,114,58,32,13,10,84,114,97,110,115,102,101,114,45,69,110,99,111,100,105,110,103,58,32,13,10,85,112,103,114,97,100,101,58,32,13,10,86,105,97,58,32,13,10,87,97,114,110,105,110,103,58,32,13,10,65,108,108,111,119,58,32,13,10,67,111,110,116,101,110,116,45,84,121,112,101,58,32,13,10,67,111,110,116,101,110,116,45,69,110,99,111,100,105,110,103,58,32,13,10,67,111,110,116,101,110,116,45,76,97,110,103,117,97,103,101,58,32,13,10,67,111,110,116,101,110,116,45,76,111,99,97,116,105,111,110,58,32,13,10,67,111,110,116,101,110,116,45,77,68,53,58,32,13,10,67,111,110,116,101,110,116,45,82,97,110,103,101,58,32,13,10,69,120,112,105,114,101,115,58,32,13,10,76,97,115,116,45,77,111,100,105,102,105,101,100,58,32,13,10,65,99,99,101,112,116,45,82,97,110,103,101,115,58,32,13,10,65,103,101,58,32,13,10,69,84,97,103,58,32,13,10,76,111,99,97,116,105,111,110,58,32,13,10,80,114,111,120,121,45,65,117,116,104,101,110,116,105,99,97,116,101,58,32,13,10,82,101,116,114,121,45,65,102,116,101,114,58,32,13,10,83,101,114,118,101,114,58,32,13,10,83,101,116,45,67,111,111,107,105,101,58,32,13,10,86,97,114,121,58,32,13,10,87,87,87,45,65,117,116,104,101,110,116,105,99,97,116,101,58,32,13,10,65,99,99,101,115,115,45,67,111,110,116,114,111,108,45,65,108,108,111,119,45,67,114,101,100,101,110,116,105,97,108,115,58,32,13,10,65,99,99,101,115,115,45,67,111,110,116,114,111,108,45,65,108,108,111,119,45,72,101,97,100,101,114,115,58,32,13,10,65,99,99,101,115,115,45,67,111,110,116,114,111,108,45,65,108,108,111,119,45,77,101,116,104,111,100,115,58,32,13,10,65,99,99,101,115,115,45,67,111,110,116,114,111,108,45,65,108,108,111,119,45,79,114,105,103,105,110,58,32,13,10,65,99,99,101,115,115,45,67,111,110,116,114,111,108,45,69,120,112,111,115,101,45,72,101,97,100,101,114,115,58,32,13,10,65,99,99,101,115,115,45,67,111,110,116,114,111,108,45,77,97,120,45,65,103,101,58,32,13,10,67,111,110,116,101,110,116,45,76,101,110,103,116,104,58,32, + }; + + private long _bits = 0; + private HeaderReferences _headers; + + public bool HasConnection => (_bits & 2L) != 0; + public bool HasDate => (_bits & 4L) != 0; + public bool HasTransferEncoding => (_bits & 64L) != 0; + public bool HasServer => (_bits & 33554432L) != 0; + + + public StringValues HeaderCacheControl + { + get + { + StringValues value; + if ((_bits & 1L) != 0) + { + value = _headers._CacheControl; + } + return value; + } + set + { + _bits |= 1L; + _headers._CacheControl = value; + } + } + public StringValues HeaderConnection + { + get + { + StringValues value; + if ((_bits & 2L) != 0) + { + value = _headers._Connection; + } + return value; + } + set + { + _bits |= 2L; + _headers._Connection = value; + _headers._rawConnection = null; + } + } + public StringValues HeaderDate + { + get + { + StringValues value; + if ((_bits & 4L) != 0) + { + value = _headers._Date; + } + return value; + } + set + { + _bits |= 4L; + _headers._Date = value; + _headers._rawDate = null; + } + } + public StringValues HeaderKeepAlive + { + get + { + StringValues value; + if ((_bits & 8L) != 0) + { + value = _headers._KeepAlive; + } + return value; + } + set + { + _bits |= 8L; + _headers._KeepAlive = value; + } + } + public StringValues HeaderPragma + { + get + { + StringValues value; + if ((_bits & 16L) != 0) + { + value = _headers._Pragma; + } + return value; + } + set + { + _bits |= 16L; + _headers._Pragma = value; + } + } + public StringValues HeaderTrailer + { + get + { + StringValues value; + if ((_bits & 32L) != 0) + { + value = _headers._Trailer; + } + return value; + } + set + { + _bits |= 32L; + _headers._Trailer = value; + } + } + public StringValues HeaderTransferEncoding + { + get + { + StringValues value; + if ((_bits & 64L) != 0) + { + value = _headers._TransferEncoding; + } + return value; + } + set + { + _bits |= 64L; + _headers._TransferEncoding = value; + _headers._rawTransferEncoding = null; + } + } + public StringValues HeaderUpgrade + { + get + { + StringValues value; + if ((_bits & 128L) != 0) + { + value = _headers._Upgrade; + } + return value; + } + set + { + _bits |= 128L; + _headers._Upgrade = value; + } + } + public StringValues HeaderVia + { + get + { + StringValues value; + if ((_bits & 256L) != 0) + { + value = _headers._Via; + } + return value; + } + set + { + _bits |= 256L; + _headers._Via = value; + } + } + public StringValues HeaderWarning + { + get + { + StringValues value; + if ((_bits & 512L) != 0) + { + value = _headers._Warning; + } + return value; + } + set + { + _bits |= 512L; + _headers._Warning = value; + } + } + public StringValues HeaderAllow + { + get + { + StringValues value; + if ((_bits & 1024L) != 0) + { + value = _headers._Allow; + } + return value; + } + set + { + _bits |= 1024L; + _headers._Allow = value; + } + } + public StringValues HeaderContentType + { + get + { + StringValues value; + if ((_bits & 2048L) != 0) + { + value = _headers._ContentType; + } + return value; + } + set + { + _bits |= 2048L; + _headers._ContentType = value; + } + } + public StringValues HeaderContentEncoding + { + get + { + StringValues value; + if ((_bits & 4096L) != 0) + { + value = _headers._ContentEncoding; + } + return value; + } + set + { + _bits |= 4096L; + _headers._ContentEncoding = value; + } + } + public StringValues HeaderContentLanguage + { + get + { + StringValues value; + if ((_bits & 8192L) != 0) + { + value = _headers._ContentLanguage; + } + return value; + } + set + { + _bits |= 8192L; + _headers._ContentLanguage = value; + } + } + public StringValues HeaderContentLocation + { + get + { + StringValues value; + if ((_bits & 16384L) != 0) + { + value = _headers._ContentLocation; + } + return value; + } + set + { + _bits |= 16384L; + _headers._ContentLocation = value; + } + } + public StringValues HeaderContentMD5 + { + get + { + StringValues value; + if ((_bits & 32768L) != 0) + { + value = _headers._ContentMD5; + } + return value; + } + set + { + _bits |= 32768L; + _headers._ContentMD5 = value; + } + } + public StringValues HeaderContentRange + { + get + { + StringValues value; + if ((_bits & 65536L) != 0) + { + value = _headers._ContentRange; + } + return value; + } + set + { + _bits |= 65536L; + _headers._ContentRange = value; + } + } + public StringValues HeaderExpires + { + get + { + StringValues value; + if ((_bits & 131072L) != 0) + { + value = _headers._Expires; + } + return value; + } + set + { + _bits |= 131072L; + _headers._Expires = value; + } + } + public StringValues HeaderLastModified + { + get + { + StringValues value; + if ((_bits & 262144L) != 0) + { + value = _headers._LastModified; + } + return value; + } + set + { + _bits |= 262144L; + _headers._LastModified = value; + } + } + public StringValues HeaderAcceptRanges + { + get + { + StringValues value; + if ((_bits & 524288L) != 0) + { + value = _headers._AcceptRanges; + } + return value; + } + set + { + _bits |= 524288L; + _headers._AcceptRanges = value; + } + } + public StringValues HeaderAge + { + get + { + StringValues value; + if ((_bits & 1048576L) != 0) + { + value = _headers._Age; + } + return value; + } + set + { + _bits |= 1048576L; + _headers._Age = value; + } + } + public StringValues HeaderETag + { + get + { + StringValues value; + if ((_bits & 2097152L) != 0) + { + value = _headers._ETag; + } + return value; + } + set + { + _bits |= 2097152L; + _headers._ETag = value; + } + } + public StringValues HeaderLocation + { + get + { + StringValues value; + if ((_bits & 4194304L) != 0) + { + value = _headers._Location; + } + return value; + } + set + { + _bits |= 4194304L; + _headers._Location = value; + } + } + public StringValues HeaderProxyAuthenticate + { + get + { + StringValues value; + if ((_bits & 8388608L) != 0) + { + value = _headers._ProxyAuthenticate; + } + return value; + } + set + { + _bits |= 8388608L; + _headers._ProxyAuthenticate = value; + } + } + public StringValues HeaderRetryAfter + { + get + { + StringValues value; + if ((_bits & 16777216L) != 0) + { + value = _headers._RetryAfter; + } + return value; + } + set + { + _bits |= 16777216L; + _headers._RetryAfter = value; + } + } + public StringValues HeaderServer + { + get + { + StringValues value; + if ((_bits & 33554432L) != 0) + { + value = _headers._Server; + } + return value; + } + set + { + _bits |= 33554432L; + _headers._Server = value; + _headers._rawServer = null; + } + } + public StringValues HeaderSetCookie + { + get + { + StringValues value; + if ((_bits & 67108864L) != 0) + { + value = _headers._SetCookie; + } + return value; + } + set + { + _bits |= 67108864L; + _headers._SetCookie = value; + } + } + public StringValues HeaderVary + { + get + { + StringValues value; + if ((_bits & 134217728L) != 0) + { + value = _headers._Vary; + } + return value; + } + set + { + _bits |= 134217728L; + _headers._Vary = value; + } + } + public StringValues HeaderWWWAuthenticate + { + get + { + StringValues value; + if ((_bits & 268435456L) != 0) + { + value = _headers._WWWAuthenticate; + } + return value; + } + set + { + _bits |= 268435456L; + _headers._WWWAuthenticate = value; + } + } + public StringValues HeaderAccessControlAllowCredentials + { + get + { + StringValues value; + if ((_bits & 536870912L) != 0) + { + value = _headers._AccessControlAllowCredentials; + } + return value; + } + set + { + _bits |= 536870912L; + _headers._AccessControlAllowCredentials = value; + } + } + public StringValues HeaderAccessControlAllowHeaders + { + get + { + StringValues value; + if ((_bits & 1073741824L) != 0) + { + value = _headers._AccessControlAllowHeaders; + } + return value; + } + set + { + _bits |= 1073741824L; + _headers._AccessControlAllowHeaders = value; + } + } + public StringValues HeaderAccessControlAllowMethods + { + get + { + StringValues value; + if ((_bits & 2147483648L) != 0) + { + value = _headers._AccessControlAllowMethods; + } + return value; + } + set + { + _bits |= 2147483648L; + _headers._AccessControlAllowMethods = value; + } + } + public StringValues HeaderAccessControlAllowOrigin + { + get + { + StringValues value; + if ((_bits & 4294967296L) != 0) + { + value = _headers._AccessControlAllowOrigin; + } + return value; + } + set + { + _bits |= 4294967296L; + _headers._AccessControlAllowOrigin = value; + } + } + public StringValues HeaderAccessControlExposeHeaders + { + get + { + StringValues value; + if ((_bits & 8589934592L) != 0) + { + value = _headers._AccessControlExposeHeaders; + } + return value; + } + set + { + _bits |= 8589934592L; + _headers._AccessControlExposeHeaders = value; + } + } + public StringValues HeaderAccessControlMaxAge + { + get + { + StringValues value; + if ((_bits & 17179869184L) != 0) + { + value = _headers._AccessControlMaxAge; + } + return value; + } + set + { + _bits |= 17179869184L; + _headers._AccessControlMaxAge = value; + } + } + public StringValues HeaderContentLength + { + get + { + StringValues value; + if (_contentLength.HasValue) + { + value = new StringValues(HeaderUtilities.FormatNonNegativeInt64(_contentLength.Value)); + } + return value; + } + set + { + _contentLength = ParseContentLength(value); + } + } + + public void SetRawConnection(in StringValues value, byte[] raw) + { + _bits |= 2L; + _headers._Connection = value; + _headers._rawConnection = raw; + } + public void SetRawDate(in StringValues value, byte[] raw) + { + _bits |= 4L; + _headers._Date = value; + _headers._rawDate = raw; + } + public void SetRawTransferEncoding(in StringValues value, byte[] raw) + { + _bits |= 64L; + _headers._TransferEncoding = value; + _headers._rawTransferEncoding = raw; + } + public void SetRawServer(in StringValues value, byte[] raw) + { + _bits |= 33554432L; + _headers._Server = value; + _headers._rawServer = raw; + } + protected override int GetCountFast() + { + return (_contentLength.HasValue ? 1 : 0 ) + BitCount(_bits) + (MaybeUnknown?.Count ?? 0); + } + + protected override bool TryGetValueFast(string key, out StringValues value) + { + switch (key.Length) + { + case 13: + { + if ("Cache-Control".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1L) != 0) + { + value = _headers._CacheControl; + return true; + } + return false; + } + if ("Content-Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 65536L) != 0) + { + value = _headers._ContentRange; + return true; + } + return false; + } + if ("Last-Modified".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 262144L) != 0) + { + value = _headers._LastModified; + return true; + } + return false; + } + if ("Accept-Ranges".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 524288L) != 0) + { + value = _headers._AcceptRanges; + return true; + } + return false; + } + } + break; + case 10: + { + if ("Connection".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2L) != 0) + { + value = _headers._Connection; + return true; + } + return false; + } + if ("Keep-Alive".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8L) != 0) + { + value = _headers._KeepAlive; + return true; + } + return false; + } + if ("Set-Cookie".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 67108864L) != 0) + { + value = _headers._SetCookie; + return true; + } + return false; + } + } + break; + case 4: + { + if ("Date".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4L) != 0) + { + value = _headers._Date; + return true; + } + return false; + } + if ("ETag".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2097152L) != 0) + { + value = _headers._ETag; + return true; + } + return false; + } + if ("Vary".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 134217728L) != 0) + { + value = _headers._Vary; + return true; + } + return false; + } + } + break; + case 6: + { + if ("Pragma".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16L) != 0) + { + value = _headers._Pragma; + return true; + } + return false; + } + if ("Server".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 33554432L) != 0) + { + value = _headers._Server; + return true; + } + return false; + } + } + break; + case 7: + { + if ("Trailer".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 32L) != 0) + { + value = _headers._Trailer; + return true; + } + return false; + } + if ("Upgrade".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 128L) != 0) + { + value = _headers._Upgrade; + return true; + } + return false; + } + if ("Warning".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 512L) != 0) + { + value = _headers._Warning; + return true; + } + return false; + } + if ("Expires".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 131072L) != 0) + { + value = _headers._Expires; + return true; + } + return false; + } + } + break; + case 17: + { + if ("Transfer-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 64L) != 0) + { + value = _headers._TransferEncoding; + return true; + } + return false; + } + } + break; + case 3: + { + if ("Via".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 256L) != 0) + { + value = _headers._Via; + return true; + } + return false; + } + if ("Age".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1048576L) != 0) + { + value = _headers._Age; + return true; + } + return false; + } + } + break; + case 5: + { + if ("Allow".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1024L) != 0) + { + value = _headers._Allow; + return true; + } + return false; + } + } + break; + case 12: + { + if ("Content-Type".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2048L) != 0) + { + value = _headers._ContentType; + return true; + } + return false; + } + } + break; + case 16: + { + if ("Content-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4096L) != 0) + { + value = _headers._ContentEncoding; + return true; + } + return false; + } + if ("Content-Language".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8192L) != 0) + { + value = _headers._ContentLanguage; + return true; + } + return false; + } + if ("Content-Location".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16384L) != 0) + { + value = _headers._ContentLocation; + return true; + } + return false; + } + if ("WWW-Authenticate".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 268435456L) != 0) + { + value = _headers._WWWAuthenticate; + return true; + } + return false; + } + } + break; + case 11: + { + if ("Content-MD5".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 32768L) != 0) + { + value = _headers._ContentMD5; + return true; + } + return false; + } + if ("Retry-After".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16777216L) != 0) + { + value = _headers._RetryAfter; + return true; + } + return false; + } + } + break; + case 8: + { + if ("Location".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4194304L) != 0) + { + value = _headers._Location; + return true; + } + return false; + } + } + break; + case 18: + { + if ("Proxy-Authenticate".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8388608L) != 0) + { + value = _headers._ProxyAuthenticate; + return true; + } + return false; + } + } + break; + case 32: + { + if ("Access-Control-Allow-Credentials".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 536870912L) != 0) + { + value = _headers._AccessControlAllowCredentials; + return true; + } + return false; + } + } + break; + case 28: + { + if ("Access-Control-Allow-Headers".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1073741824L) != 0) + { + value = _headers._AccessControlAllowHeaders; + return true; + } + return false; + } + if ("Access-Control-Allow-Methods".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2147483648L) != 0) + { + value = _headers._AccessControlAllowMethods; + return true; + } + return false; + } + } + break; + case 27: + { + if ("Access-Control-Allow-Origin".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4294967296L) != 0) + { + value = _headers._AccessControlAllowOrigin; + return true; + } + return false; + } + } + break; + case 29: + { + if ("Access-Control-Expose-Headers".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8589934592L) != 0) + { + value = _headers._AccessControlExposeHeaders; + return true; + } + return false; + } + } + break; + case 22: + { + if ("Access-Control-Max-Age".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 17179869184L) != 0) + { + value = _headers._AccessControlMaxAge; + return true; + } + return false; + } + } + break; + case 14: + { + if ("Content-Length".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if (_contentLength.HasValue) + { + value = HeaderUtilities.FormatNonNegativeInt64(_contentLength.Value); + return true; + } + return false; + } + } + break; + } + + return MaybeUnknown?.TryGetValue(key, out value) ?? false; + } + + protected override void SetValueFast(string key, in StringValues value) + { + ValidateHeaderCharacters(value); + switch (key.Length) + { + case 13: + { + if ("Cache-Control".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 1L; + _headers._CacheControl = value; + return; + } + if ("Content-Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 65536L; + _headers._ContentRange = value; + return; + } + if ("Last-Modified".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 262144L; + _headers._LastModified = value; + return; + } + if ("Accept-Ranges".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 524288L; + _headers._AcceptRanges = value; + return; + } + } + break; + case 10: + { + if ("Connection".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 2L; + _headers._Connection = value; + _headers._rawConnection = null; + return; + } + if ("Keep-Alive".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 8L; + _headers._KeepAlive = value; + return; + } + if ("Set-Cookie".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 67108864L; + _headers._SetCookie = value; + return; + } + } + break; + case 4: + { + if ("Date".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 4L; + _headers._Date = value; + _headers._rawDate = null; + return; + } + if ("ETag".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 2097152L; + _headers._ETag = value; + return; + } + if ("Vary".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 134217728L; + _headers._Vary = value; + return; + } + } + break; + case 6: + { + if ("Pragma".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 16L; + _headers._Pragma = value; + return; + } + if ("Server".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 33554432L; + _headers._Server = value; + _headers._rawServer = null; + return; + } + } + break; + case 7: + { + if ("Trailer".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 32L; + _headers._Trailer = value; + return; + } + if ("Upgrade".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 128L; + _headers._Upgrade = value; + return; + } + if ("Warning".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 512L; + _headers._Warning = value; + return; + } + if ("Expires".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 131072L; + _headers._Expires = value; + return; + } + } + break; + case 17: + { + if ("Transfer-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 64L; + _headers._TransferEncoding = value; + _headers._rawTransferEncoding = null; + return; + } + } + break; + case 3: + { + if ("Via".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 256L; + _headers._Via = value; + return; + } + if ("Age".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 1048576L; + _headers._Age = value; + return; + } + } + break; + case 5: + { + if ("Allow".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 1024L; + _headers._Allow = value; + return; + } + } + break; + case 12: + { + if ("Content-Type".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 2048L; + _headers._ContentType = value; + return; + } + } + break; + case 16: + { + if ("Content-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 4096L; + _headers._ContentEncoding = value; + return; + } + if ("Content-Language".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 8192L; + _headers._ContentLanguage = value; + return; + } + if ("Content-Location".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 16384L; + _headers._ContentLocation = value; + return; + } + if ("WWW-Authenticate".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 268435456L; + _headers._WWWAuthenticate = value; + return; + } + } + break; + case 11: + { + if ("Content-MD5".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 32768L; + _headers._ContentMD5 = value; + return; + } + if ("Retry-After".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 16777216L; + _headers._RetryAfter = value; + return; + } + } + break; + case 8: + { + if ("Location".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 4194304L; + _headers._Location = value; + return; + } + } + break; + case 18: + { + if ("Proxy-Authenticate".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 8388608L; + _headers._ProxyAuthenticate = value; + return; + } + } + break; + case 32: + { + if ("Access-Control-Allow-Credentials".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 536870912L; + _headers._AccessControlAllowCredentials = value; + return; + } + } + break; + case 28: + { + if ("Access-Control-Allow-Headers".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 1073741824L; + _headers._AccessControlAllowHeaders = value; + return; + } + if ("Access-Control-Allow-Methods".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 2147483648L; + _headers._AccessControlAllowMethods = value; + return; + } + } + break; + case 27: + { + if ("Access-Control-Allow-Origin".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 4294967296L; + _headers._AccessControlAllowOrigin = value; + return; + } + } + break; + case 29: + { + if ("Access-Control-Expose-Headers".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 8589934592L; + _headers._AccessControlExposeHeaders = value; + return; + } + } + break; + case 22: + { + if ("Access-Control-Max-Age".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _bits |= 17179869184L; + _headers._AccessControlMaxAge = value; + return; + } + } + break; + case 14: + { + if ("Content-Length".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + _contentLength = ParseContentLength(value.ToString()); + return; + } + } + break; + } + + SetValueUnknown(key, value); + } + + protected override bool AddValueFast(string key, in StringValues value) + { + ValidateHeaderCharacters(value); + switch (key.Length) + { + case 13: + { + if ("Cache-Control".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1L) == 0) + { + _bits |= 1L; + _headers._CacheControl = value; + return true; + } + return false; + } + if ("Content-Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 65536L) == 0) + { + _bits |= 65536L; + _headers._ContentRange = value; + return true; + } + return false; + } + if ("Last-Modified".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 262144L) == 0) + { + _bits |= 262144L; + _headers._LastModified = value; + return true; + } + return false; + } + if ("Accept-Ranges".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 524288L) == 0) + { + _bits |= 524288L; + _headers._AcceptRanges = value; + return true; + } + return false; + } + } + break; + case 10: + { + if ("Connection".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2L) == 0) + { + _bits |= 2L; + _headers._Connection = value; + _headers._rawConnection = null; + return true; + } + return false; + } + if ("Keep-Alive".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8L) == 0) + { + _bits |= 8L; + _headers._KeepAlive = value; + return true; + } + return false; + } + if ("Set-Cookie".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 67108864L) == 0) + { + _bits |= 67108864L; + _headers._SetCookie = value; + return true; + } + return false; + } + } + break; + case 4: + { + if ("Date".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4L) == 0) + { + _bits |= 4L; + _headers._Date = value; + _headers._rawDate = null; + return true; + } + return false; + } + if ("ETag".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2097152L) == 0) + { + _bits |= 2097152L; + _headers._ETag = value; + return true; + } + return false; + } + if ("Vary".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 134217728L) == 0) + { + _bits |= 134217728L; + _headers._Vary = value; + return true; + } + return false; + } + } + break; + case 6: + { + if ("Pragma".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16L) == 0) + { + _bits |= 16L; + _headers._Pragma = value; + return true; + } + return false; + } + if ("Server".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 33554432L) == 0) + { + _bits |= 33554432L; + _headers._Server = value; + _headers._rawServer = null; + return true; + } + return false; + } + } + break; + case 7: + { + if ("Trailer".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 32L) == 0) + { + _bits |= 32L; + _headers._Trailer = value; + return true; + } + return false; + } + if ("Upgrade".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 128L) == 0) + { + _bits |= 128L; + _headers._Upgrade = value; + return true; + } + return false; + } + if ("Warning".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 512L) == 0) + { + _bits |= 512L; + _headers._Warning = value; + return true; + } + return false; + } + if ("Expires".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 131072L) == 0) + { + _bits |= 131072L; + _headers._Expires = value; + return true; + } + return false; + } + } + break; + case 17: + { + if ("Transfer-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 64L) == 0) + { + _bits |= 64L; + _headers._TransferEncoding = value; + _headers._rawTransferEncoding = null; + return true; + } + return false; + } + } + break; + case 3: + { + if ("Via".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 256L) == 0) + { + _bits |= 256L; + _headers._Via = value; + return true; + } + return false; + } + if ("Age".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1048576L) == 0) + { + _bits |= 1048576L; + _headers._Age = value; + return true; + } + return false; + } + } + break; + case 5: + { + if ("Allow".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1024L) == 0) + { + _bits |= 1024L; + _headers._Allow = value; + return true; + } + return false; + } + } + break; + case 12: + { + if ("Content-Type".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2048L) == 0) + { + _bits |= 2048L; + _headers._ContentType = value; + return true; + } + return false; + } + } + break; + case 16: + { + if ("Content-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4096L) == 0) + { + _bits |= 4096L; + _headers._ContentEncoding = value; + return true; + } + return false; + } + if ("Content-Language".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8192L) == 0) + { + _bits |= 8192L; + _headers._ContentLanguage = value; + return true; + } + return false; + } + if ("Content-Location".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16384L) == 0) + { + _bits |= 16384L; + _headers._ContentLocation = value; + return true; + } + return false; + } + if ("WWW-Authenticate".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 268435456L) == 0) + { + _bits |= 268435456L; + _headers._WWWAuthenticate = value; + return true; + } + return false; + } + } + break; + case 11: + { + if ("Content-MD5".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 32768L) == 0) + { + _bits |= 32768L; + _headers._ContentMD5 = value; + return true; + } + return false; + } + if ("Retry-After".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16777216L) == 0) + { + _bits |= 16777216L; + _headers._RetryAfter = value; + return true; + } + return false; + } + } + break; + case 8: + { + if ("Location".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4194304L) == 0) + { + _bits |= 4194304L; + _headers._Location = value; + return true; + } + return false; + } + } + break; + case 18: + { + if ("Proxy-Authenticate".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8388608L) == 0) + { + _bits |= 8388608L; + _headers._ProxyAuthenticate = value; + return true; + } + return false; + } + } + break; + case 32: + { + if ("Access-Control-Allow-Credentials".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 536870912L) == 0) + { + _bits |= 536870912L; + _headers._AccessControlAllowCredentials = value; + return true; + } + return false; + } + } + break; + case 28: + { + if ("Access-Control-Allow-Headers".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1073741824L) == 0) + { + _bits |= 1073741824L; + _headers._AccessControlAllowHeaders = value; + return true; + } + return false; + } + if ("Access-Control-Allow-Methods".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2147483648L) == 0) + { + _bits |= 2147483648L; + _headers._AccessControlAllowMethods = value; + return true; + } + return false; + } + } + break; + case 27: + { + if ("Access-Control-Allow-Origin".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4294967296L) == 0) + { + _bits |= 4294967296L; + _headers._AccessControlAllowOrigin = value; + return true; + } + return false; + } + } + break; + case 29: + { + if ("Access-Control-Expose-Headers".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8589934592L) == 0) + { + _bits |= 8589934592L; + _headers._AccessControlExposeHeaders = value; + return true; + } + return false; + } + } + break; + case 22: + { + if ("Access-Control-Max-Age".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 17179869184L) == 0) + { + _bits |= 17179869184L; + _headers._AccessControlMaxAge = value; + return true; + } + return false; + } + } + break; + case 14: + { + if ("Content-Length".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if (!_contentLength.HasValue) + { + _contentLength = ParseContentLength(value); + return true; + } + return false; + } + } + break; + } + + ValidateHeaderCharacters(key); + Unknown.Add(key, value); + // Return true, above will throw and exit for false + return true; + } + + protected override bool RemoveFast(string key) + { + switch (key.Length) + { + case 13: + { + if ("Cache-Control".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1L) != 0) + { + _bits &= ~1L; + _headers._CacheControl = default(StringValues); + return true; + } + return false; + } + if ("Content-Range".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 65536L) != 0) + { + _bits &= ~65536L; + _headers._ContentRange = default(StringValues); + return true; + } + return false; + } + if ("Last-Modified".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 262144L) != 0) + { + _bits &= ~262144L; + _headers._LastModified = default(StringValues); + return true; + } + return false; + } + if ("Accept-Ranges".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 524288L) != 0) + { + _bits &= ~524288L; + _headers._AcceptRanges = default(StringValues); + return true; + } + return false; + } + } + break; + case 10: + { + if ("Connection".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2L) != 0) + { + _bits &= ~2L; + _headers._Connection = default(StringValues); + _headers._rawConnection = null; + return true; + } + return false; + } + if ("Keep-Alive".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8L) != 0) + { + _bits &= ~8L; + _headers._KeepAlive = default(StringValues); + return true; + } + return false; + } + if ("Set-Cookie".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 67108864L) != 0) + { + _bits &= ~67108864L; + _headers._SetCookie = default(StringValues); + return true; + } + return false; + } + } + break; + case 4: + { + if ("Date".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4L) != 0) + { + _bits &= ~4L; + _headers._Date = default(StringValues); + _headers._rawDate = null; + return true; + } + return false; + } + if ("ETag".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2097152L) != 0) + { + _bits &= ~2097152L; + _headers._ETag = default(StringValues); + return true; + } + return false; + } + if ("Vary".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 134217728L) != 0) + { + _bits &= ~134217728L; + _headers._Vary = default(StringValues); + return true; + } + return false; + } + } + break; + case 6: + { + if ("Pragma".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16L) != 0) + { + _bits &= ~16L; + _headers._Pragma = default(StringValues); + return true; + } + return false; + } + if ("Server".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 33554432L) != 0) + { + _bits &= ~33554432L; + _headers._Server = default(StringValues); + _headers._rawServer = null; + return true; + } + return false; + } + } + break; + case 7: + { + if ("Trailer".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 32L) != 0) + { + _bits &= ~32L; + _headers._Trailer = default(StringValues); + return true; + } + return false; + } + if ("Upgrade".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 128L) != 0) + { + _bits &= ~128L; + _headers._Upgrade = default(StringValues); + return true; + } + return false; + } + if ("Warning".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 512L) != 0) + { + _bits &= ~512L; + _headers._Warning = default(StringValues); + return true; + } + return false; + } + if ("Expires".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 131072L) != 0) + { + _bits &= ~131072L; + _headers._Expires = default(StringValues); + return true; + } + return false; + } + } + break; + case 17: + { + if ("Transfer-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 64L) != 0) + { + _bits &= ~64L; + _headers._TransferEncoding = default(StringValues); + _headers._rawTransferEncoding = null; + return true; + } + return false; + } + } + break; + case 3: + { + if ("Via".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 256L) != 0) + { + _bits &= ~256L; + _headers._Via = default(StringValues); + return true; + } + return false; + } + if ("Age".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1048576L) != 0) + { + _bits &= ~1048576L; + _headers._Age = default(StringValues); + return true; + } + return false; + } + } + break; + case 5: + { + if ("Allow".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1024L) != 0) + { + _bits &= ~1024L; + _headers._Allow = default(StringValues); + return true; + } + return false; + } + } + break; + case 12: + { + if ("Content-Type".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2048L) != 0) + { + _bits &= ~2048L; + _headers._ContentType = default(StringValues); + return true; + } + return false; + } + } + break; + case 16: + { + if ("Content-Encoding".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4096L) != 0) + { + _bits &= ~4096L; + _headers._ContentEncoding = default(StringValues); + return true; + } + return false; + } + if ("Content-Language".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8192L) != 0) + { + _bits &= ~8192L; + _headers._ContentLanguage = default(StringValues); + return true; + } + return false; + } + if ("Content-Location".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16384L) != 0) + { + _bits &= ~16384L; + _headers._ContentLocation = default(StringValues); + return true; + } + return false; + } + if ("WWW-Authenticate".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 268435456L) != 0) + { + _bits &= ~268435456L; + _headers._WWWAuthenticate = default(StringValues); + return true; + } + return false; + } + } + break; + case 11: + { + if ("Content-MD5".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 32768L) != 0) + { + _bits &= ~32768L; + _headers._ContentMD5 = default(StringValues); + return true; + } + return false; + } + if ("Retry-After".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 16777216L) != 0) + { + _bits &= ~16777216L; + _headers._RetryAfter = default(StringValues); + return true; + } + return false; + } + } + break; + case 8: + { + if ("Location".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4194304L) != 0) + { + _bits &= ~4194304L; + _headers._Location = default(StringValues); + return true; + } + return false; + } + } + break; + case 18: + { + if ("Proxy-Authenticate".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8388608L) != 0) + { + _bits &= ~8388608L; + _headers._ProxyAuthenticate = default(StringValues); + return true; + } + return false; + } + } + break; + case 32: + { + if ("Access-Control-Allow-Credentials".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 536870912L) != 0) + { + _bits &= ~536870912L; + _headers._AccessControlAllowCredentials = default(StringValues); + return true; + } + return false; + } + } + break; + case 28: + { + if ("Access-Control-Allow-Headers".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 1073741824L) != 0) + { + _bits &= ~1073741824L; + _headers._AccessControlAllowHeaders = default(StringValues); + return true; + } + return false; + } + if ("Access-Control-Allow-Methods".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 2147483648L) != 0) + { + _bits &= ~2147483648L; + _headers._AccessControlAllowMethods = default(StringValues); + return true; + } + return false; + } + } + break; + case 27: + { + if ("Access-Control-Allow-Origin".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 4294967296L) != 0) + { + _bits &= ~4294967296L; + _headers._AccessControlAllowOrigin = default(StringValues); + return true; + } + return false; + } + } + break; + case 29: + { + if ("Access-Control-Expose-Headers".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 8589934592L) != 0) + { + _bits &= ~8589934592L; + _headers._AccessControlExposeHeaders = default(StringValues); + return true; + } + return false; + } + } + break; + case 22: + { + if ("Access-Control-Max-Age".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if ((_bits & 17179869184L) != 0) + { + _bits &= ~17179869184L; + _headers._AccessControlMaxAge = default(StringValues); + return true; + } + return false; + } + } + break; + case 14: + { + if ("Content-Length".Equals(key, StringComparison.OrdinalIgnoreCase)) + { + if (_contentLength.HasValue) + { + _contentLength = null; + return true; + } + return false; + } + } + break; + } + + return MaybeUnknown?.Remove(key) ?? false; + } + + protected override void ClearFast() + { + MaybeUnknown?.Clear(); + _contentLength = null; + var tempBits = _bits; + _bits = 0; + if(HttpHeaders.BitCount(tempBits) > 12) + { + _headers = default(HeaderReferences); + return; + } + + if ((tempBits & 2L) != 0) + { + _headers._Connection = default(StringValues); + if((tempBits & ~2L) == 0) + { + return; + } + tempBits &= ~2L; + } + + if ((tempBits & 4L) != 0) + { + _headers._Date = default(StringValues); + if((tempBits & ~4L) == 0) + { + return; + } + tempBits &= ~4L; + } + + if ((tempBits & 2048L) != 0) + { + _headers._ContentType = default(StringValues); + if((tempBits & ~2048L) == 0) + { + return; + } + tempBits &= ~2048L; + } + + if ((tempBits & 33554432L) != 0) + { + _headers._Server = default(StringValues); + if((tempBits & ~33554432L) == 0) + { + return; + } + tempBits &= ~33554432L; + } + + if ((tempBits & 1L) != 0) + { + _headers._CacheControl = default(StringValues); + if((tempBits & ~1L) == 0) + { + return; + } + tempBits &= ~1L; + } + + if ((tempBits & 8L) != 0) + { + _headers._KeepAlive = default(StringValues); + if((tempBits & ~8L) == 0) + { + return; + } + tempBits &= ~8L; + } + + if ((tempBits & 16L) != 0) + { + _headers._Pragma = default(StringValues); + if((tempBits & ~16L) == 0) + { + return; + } + tempBits &= ~16L; + } + + if ((tempBits & 32L) != 0) + { + _headers._Trailer = default(StringValues); + if((tempBits & ~32L) == 0) + { + return; + } + tempBits &= ~32L; + } + + if ((tempBits & 64L) != 0) + { + _headers._TransferEncoding = default(StringValues); + if((tempBits & ~64L) == 0) + { + return; + } + tempBits &= ~64L; + } + + if ((tempBits & 128L) != 0) + { + _headers._Upgrade = default(StringValues); + if((tempBits & ~128L) == 0) + { + return; + } + tempBits &= ~128L; + } + + if ((tempBits & 256L) != 0) + { + _headers._Via = default(StringValues); + if((tempBits & ~256L) == 0) + { + return; + } + tempBits &= ~256L; + } + + if ((tempBits & 512L) != 0) + { + _headers._Warning = default(StringValues); + if((tempBits & ~512L) == 0) + { + return; + } + tempBits &= ~512L; + } + + if ((tempBits & 1024L) != 0) + { + _headers._Allow = default(StringValues); + if((tempBits & ~1024L) == 0) + { + return; + } + tempBits &= ~1024L; + } + + if ((tempBits & 4096L) != 0) + { + _headers._ContentEncoding = default(StringValues); + if((tempBits & ~4096L) == 0) + { + return; + } + tempBits &= ~4096L; + } + + if ((tempBits & 8192L) != 0) + { + _headers._ContentLanguage = default(StringValues); + if((tempBits & ~8192L) == 0) + { + return; + } + tempBits &= ~8192L; + } + + if ((tempBits & 16384L) != 0) + { + _headers._ContentLocation = default(StringValues); + if((tempBits & ~16384L) == 0) + { + return; + } + tempBits &= ~16384L; + } + + if ((tempBits & 32768L) != 0) + { + _headers._ContentMD5 = default(StringValues); + if((tempBits & ~32768L) == 0) + { + return; + } + tempBits &= ~32768L; + } + + if ((tempBits & 65536L) != 0) + { + _headers._ContentRange = default(StringValues); + if((tempBits & ~65536L) == 0) + { + return; + } + tempBits &= ~65536L; + } + + if ((tempBits & 131072L) != 0) + { + _headers._Expires = default(StringValues); + if((tempBits & ~131072L) == 0) + { + return; + } + tempBits &= ~131072L; + } + + if ((tempBits & 262144L) != 0) + { + _headers._LastModified = default(StringValues); + if((tempBits & ~262144L) == 0) + { + return; + } + tempBits &= ~262144L; + } + + if ((tempBits & 524288L) != 0) + { + _headers._AcceptRanges = default(StringValues); + if((tempBits & ~524288L) == 0) + { + return; + } + tempBits &= ~524288L; + } + + if ((tempBits & 1048576L) != 0) + { + _headers._Age = default(StringValues); + if((tempBits & ~1048576L) == 0) + { + return; + } + tempBits &= ~1048576L; + } + + if ((tempBits & 2097152L) != 0) + { + _headers._ETag = default(StringValues); + if((tempBits & ~2097152L) == 0) + { + return; + } + tempBits &= ~2097152L; + } + + if ((tempBits & 4194304L) != 0) + { + _headers._Location = default(StringValues); + if((tempBits & ~4194304L) == 0) + { + return; + } + tempBits &= ~4194304L; + } + + if ((tempBits & 8388608L) != 0) + { + _headers._ProxyAuthenticate = default(StringValues); + if((tempBits & ~8388608L) == 0) + { + return; + } + tempBits &= ~8388608L; + } + + if ((tempBits & 16777216L) != 0) + { + _headers._RetryAfter = default(StringValues); + if((tempBits & ~16777216L) == 0) + { + return; + } + tempBits &= ~16777216L; + } + + if ((tempBits & 67108864L) != 0) + { + _headers._SetCookie = default(StringValues); + if((tempBits & ~67108864L) == 0) + { + return; + } + tempBits &= ~67108864L; + } + + if ((tempBits & 134217728L) != 0) + { + _headers._Vary = default(StringValues); + if((tempBits & ~134217728L) == 0) + { + return; + } + tempBits &= ~134217728L; + } + + if ((tempBits & 268435456L) != 0) + { + _headers._WWWAuthenticate = default(StringValues); + if((tempBits & ~268435456L) == 0) + { + return; + } + tempBits &= ~268435456L; + } + + if ((tempBits & 536870912L) != 0) + { + _headers._AccessControlAllowCredentials = default(StringValues); + if((tempBits & ~536870912L) == 0) + { + return; + } + tempBits &= ~536870912L; + } + + if ((tempBits & 1073741824L) != 0) + { + _headers._AccessControlAllowHeaders = default(StringValues); + if((tempBits & ~1073741824L) == 0) + { + return; + } + tempBits &= ~1073741824L; + } + + if ((tempBits & 2147483648L) != 0) + { + _headers._AccessControlAllowMethods = default(StringValues); + if((tempBits & ~2147483648L) == 0) + { + return; + } + tempBits &= ~2147483648L; + } + + if ((tempBits & 4294967296L) != 0) + { + _headers._AccessControlAllowOrigin = default(StringValues); + if((tempBits & ~4294967296L) == 0) + { + return; + } + tempBits &= ~4294967296L; + } + + if ((tempBits & 8589934592L) != 0) + { + _headers._AccessControlExposeHeaders = default(StringValues); + if((tempBits & ~8589934592L) == 0) + { + return; + } + tempBits &= ~8589934592L; + } + + if ((tempBits & 17179869184L) != 0) + { + _headers._AccessControlMaxAge = default(StringValues); + if((tempBits & ~17179869184L) == 0) + { + return; + } + tempBits &= ~17179869184L; + } + + } + + protected override bool CopyToFast(KeyValuePair[] array, int arrayIndex) + { + if (arrayIndex < 0) + { + return false; + } + + if ((_bits & 1L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Cache-Control", _headers._CacheControl); + ++arrayIndex; + } + if ((_bits & 2L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Connection", _headers._Connection); + ++arrayIndex; + } + if ((_bits & 4L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Date", _headers._Date); + ++arrayIndex; + } + if ((_bits & 8L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Keep-Alive", _headers._KeepAlive); + ++arrayIndex; + } + if ((_bits & 16L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Pragma", _headers._Pragma); + ++arrayIndex; + } + if ((_bits & 32L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Trailer", _headers._Trailer); + ++arrayIndex; + } + if ((_bits & 64L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Transfer-Encoding", _headers._TransferEncoding); + ++arrayIndex; + } + if ((_bits & 128L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Upgrade", _headers._Upgrade); + ++arrayIndex; + } + if ((_bits & 256L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Via", _headers._Via); + ++arrayIndex; + } + if ((_bits & 512L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Warning", _headers._Warning); + ++arrayIndex; + } + if ((_bits & 1024L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Allow", _headers._Allow); + ++arrayIndex; + } + if ((_bits & 2048L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-Type", _headers._ContentType); + ++arrayIndex; + } + if ((_bits & 4096L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-Encoding", _headers._ContentEncoding); + ++arrayIndex; + } + if ((_bits & 8192L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-Language", _headers._ContentLanguage); + ++arrayIndex; + } + if ((_bits & 16384L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-Location", _headers._ContentLocation); + ++arrayIndex; + } + if ((_bits & 32768L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-MD5", _headers._ContentMD5); + ++arrayIndex; + } + if ((_bits & 65536L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-Range", _headers._ContentRange); + ++arrayIndex; + } + if ((_bits & 131072L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Expires", _headers._Expires); + ++arrayIndex; + } + if ((_bits & 262144L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Last-Modified", _headers._LastModified); + ++arrayIndex; + } + if ((_bits & 524288L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Accept-Ranges", _headers._AcceptRanges); + ++arrayIndex; + } + if ((_bits & 1048576L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Age", _headers._Age); + ++arrayIndex; + } + if ((_bits & 2097152L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("ETag", _headers._ETag); + ++arrayIndex; + } + if ((_bits & 4194304L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Location", _headers._Location); + ++arrayIndex; + } + if ((_bits & 8388608L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Proxy-Authenticate", _headers._ProxyAuthenticate); + ++arrayIndex; + } + if ((_bits & 16777216L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Retry-After", _headers._RetryAfter); + ++arrayIndex; + } + if ((_bits & 33554432L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Server", _headers._Server); + ++arrayIndex; + } + if ((_bits & 67108864L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Set-Cookie", _headers._SetCookie); + ++arrayIndex; + } + if ((_bits & 134217728L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Vary", _headers._Vary); + ++arrayIndex; + } + if ((_bits & 268435456L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("WWW-Authenticate", _headers._WWWAuthenticate); + ++arrayIndex; + } + if ((_bits & 536870912L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Access-Control-Allow-Credentials", _headers._AccessControlAllowCredentials); + ++arrayIndex; + } + if ((_bits & 1073741824L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Access-Control-Allow-Headers", _headers._AccessControlAllowHeaders); + ++arrayIndex; + } + if ((_bits & 2147483648L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Access-Control-Allow-Methods", _headers._AccessControlAllowMethods); + ++arrayIndex; + } + if ((_bits & 4294967296L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Access-Control-Allow-Origin", _headers._AccessControlAllowOrigin); + ++arrayIndex; + } + if ((_bits & 8589934592L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Access-Control-Expose-Headers", _headers._AccessControlExposeHeaders); + ++arrayIndex; + } + if ((_bits & 17179869184L) != 0) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Access-Control-Max-Age", _headers._AccessControlMaxAge); + ++arrayIndex; + } + if (_contentLength.HasValue) + { + if (arrayIndex == array.Length) + { + return false; + } + array[arrayIndex] = new KeyValuePair("Content-Length", HeaderUtilities.FormatNonNegativeInt64(_contentLength.Value)); + ++arrayIndex; + } + ((ICollection>)MaybeUnknown)?.CopyTo(array, arrayIndex); + + return true; + } + + internal void CopyToFast(ref BufferWriter output) + { + var tempBits = _bits | (_contentLength.HasValue ? -9223372036854775808L : 0); + + if ((tempBits & 2L) != 0) + { + if (_headers._rawConnection != null) + { + output.Write(_headers._rawConnection); + } + else + { + var valueCount = _headers._Connection.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Connection[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 17, 14)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~2L) == 0) + { + return; + } + tempBits &= ~2L; + } + if ((tempBits & 4L) != 0) + { + if (_headers._rawDate != null) + { + output.Write(_headers._rawDate); + } + else + { + var valueCount = _headers._Date.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Date[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 31, 8)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~4L) == 0) + { + return; + } + tempBits &= ~4L; + } + if ((tempBits & 2048L) != 0) + { + { + var valueCount = _headers._ContentType.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._ContentType[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 133, 16)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~2048L) == 0) + { + return; + } + tempBits &= ~2048L; + } + if ((tempBits & 33554432L) != 0) + { + if (_headers._rawServer != null) + { + output.Write(_headers._rawServer); + } + else + { + var valueCount = _headers._Server.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Server[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 350, 10)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~33554432L) == 0) + { + return; + } + tempBits &= ~33554432L; + } + if ((tempBits & -9223372036854775808L) != 0) + { + output.Write(new ReadOnlySpan(_headerBytes, 592, 18)); + PipelineExtensions.WriteNumeric(ref output, (ulong)ContentLength.Value); + + if((tempBits & ~-9223372036854775808L) == 0) + { + return; + } + tempBits &= ~-9223372036854775808L; + } + if ((tempBits & 1L) != 0) + { + { + var valueCount = _headers._CacheControl.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._CacheControl[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 0, 17)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~1L) == 0) + { + return; + } + tempBits &= ~1L; + } + if ((tempBits & 8L) != 0) + { + { + var valueCount = _headers._KeepAlive.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._KeepAlive[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 39, 14)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~8L) == 0) + { + return; + } + tempBits &= ~8L; + } + if ((tempBits & 16L) != 0) + { + { + var valueCount = _headers._Pragma.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Pragma[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 53, 10)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~16L) == 0) + { + return; + } + tempBits &= ~16L; + } + if ((tempBits & 32L) != 0) + { + { + var valueCount = _headers._Trailer.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Trailer[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 63, 11)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~32L) == 0) + { + return; + } + tempBits &= ~32L; + } + if ((tempBits & 64L) != 0) + { + if (_headers._rawTransferEncoding != null) + { + output.Write(_headers._rawTransferEncoding); + } + else + { + var valueCount = _headers._TransferEncoding.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._TransferEncoding[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 74, 21)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~64L) == 0) + { + return; + } + tempBits &= ~64L; + } + if ((tempBits & 128L) != 0) + { + { + var valueCount = _headers._Upgrade.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Upgrade[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 95, 11)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~128L) == 0) + { + return; + } + tempBits &= ~128L; + } + if ((tempBits & 256L) != 0) + { + { + var valueCount = _headers._Via.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Via[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 106, 7)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~256L) == 0) + { + return; + } + tempBits &= ~256L; + } + if ((tempBits & 512L) != 0) + { + { + var valueCount = _headers._Warning.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Warning[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 113, 11)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~512L) == 0) + { + return; + } + tempBits &= ~512L; + } + if ((tempBits & 1024L) != 0) + { + { + var valueCount = _headers._Allow.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Allow[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 124, 9)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~1024L) == 0) + { + return; + } + tempBits &= ~1024L; + } + if ((tempBits & 4096L) != 0) + { + { + var valueCount = _headers._ContentEncoding.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._ContentEncoding[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 149, 20)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~4096L) == 0) + { + return; + } + tempBits &= ~4096L; + } + if ((tempBits & 8192L) != 0) + { + { + var valueCount = _headers._ContentLanguage.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._ContentLanguage[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 169, 20)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~8192L) == 0) + { + return; + } + tempBits &= ~8192L; + } + if ((tempBits & 16384L) != 0) + { + { + var valueCount = _headers._ContentLocation.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._ContentLocation[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 189, 20)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~16384L) == 0) + { + return; + } + tempBits &= ~16384L; + } + if ((tempBits & 32768L) != 0) + { + { + var valueCount = _headers._ContentMD5.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._ContentMD5[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 209, 15)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~32768L) == 0) + { + return; + } + tempBits &= ~32768L; + } + if ((tempBits & 65536L) != 0) + { + { + var valueCount = _headers._ContentRange.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._ContentRange[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 224, 17)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~65536L) == 0) + { + return; + } + tempBits &= ~65536L; + } + if ((tempBits & 131072L) != 0) + { + { + var valueCount = _headers._Expires.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Expires[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 241, 11)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~131072L) == 0) + { + return; + } + tempBits &= ~131072L; + } + if ((tempBits & 262144L) != 0) + { + { + var valueCount = _headers._LastModified.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._LastModified[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 252, 17)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~262144L) == 0) + { + return; + } + tempBits &= ~262144L; + } + if ((tempBits & 524288L) != 0) + { + { + var valueCount = _headers._AcceptRanges.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._AcceptRanges[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 269, 17)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~524288L) == 0) + { + return; + } + tempBits &= ~524288L; + } + if ((tempBits & 1048576L) != 0) + { + { + var valueCount = _headers._Age.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Age[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 286, 7)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~1048576L) == 0) + { + return; + } + tempBits &= ~1048576L; + } + if ((tempBits & 2097152L) != 0) + { + { + var valueCount = _headers._ETag.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._ETag[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 293, 8)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~2097152L) == 0) + { + return; + } + tempBits &= ~2097152L; + } + if ((tempBits & 4194304L) != 0) + { + { + var valueCount = _headers._Location.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Location[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 301, 12)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~4194304L) == 0) + { + return; + } + tempBits &= ~4194304L; + } + if ((tempBits & 8388608L) != 0) + { + { + var valueCount = _headers._ProxyAuthenticate.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._ProxyAuthenticate[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 313, 22)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~8388608L) == 0) + { + return; + } + tempBits &= ~8388608L; + } + if ((tempBits & 16777216L) != 0) + { + { + var valueCount = _headers._RetryAfter.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._RetryAfter[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 335, 15)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~16777216L) == 0) + { + return; + } + tempBits &= ~16777216L; + } + if ((tempBits & 67108864L) != 0) + { + { + var valueCount = _headers._SetCookie.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._SetCookie[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 360, 14)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~67108864L) == 0) + { + return; + } + tempBits &= ~67108864L; + } + if ((tempBits & 134217728L) != 0) + { + { + var valueCount = _headers._Vary.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._Vary[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 374, 8)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~134217728L) == 0) + { + return; + } + tempBits &= ~134217728L; + } + if ((tempBits & 268435456L) != 0) + { + { + var valueCount = _headers._WWWAuthenticate.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._WWWAuthenticate[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 382, 20)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~268435456L) == 0) + { + return; + } + tempBits &= ~268435456L; + } + if ((tempBits & 536870912L) != 0) + { + { + var valueCount = _headers._AccessControlAllowCredentials.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._AccessControlAllowCredentials[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 402, 36)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~536870912L) == 0) + { + return; + } + tempBits &= ~536870912L; + } + if ((tempBits & 1073741824L) != 0) + { + { + var valueCount = _headers._AccessControlAllowHeaders.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._AccessControlAllowHeaders[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 438, 32)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~1073741824L) == 0) + { + return; + } + tempBits &= ~1073741824L; + } + if ((tempBits & 2147483648L) != 0) + { + { + var valueCount = _headers._AccessControlAllowMethods.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._AccessControlAllowMethods[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 470, 32)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~2147483648L) == 0) + { + return; + } + tempBits &= ~2147483648L; + } + if ((tempBits & 4294967296L) != 0) + { + { + var valueCount = _headers._AccessControlAllowOrigin.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._AccessControlAllowOrigin[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 502, 31)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~4294967296L) == 0) + { + return; + } + tempBits &= ~4294967296L; + } + if ((tempBits & 8589934592L) != 0) + { + { + var valueCount = _headers._AccessControlExposeHeaders.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._AccessControlExposeHeaders[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 533, 33)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~8589934592L) == 0) + { + return; + } + tempBits &= ~8589934592L; + } + if ((tempBits & 17179869184L) != 0) + { + { + var valueCount = _headers._AccessControlMaxAge.Count; + for (var i = 0; i < valueCount; i++) + { + var value = _headers._AccessControlMaxAge[i]; + if (value != null) + { + output.Write(new ReadOnlySpan(_headerBytes, 566, 26)); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + } + } + } + + if((tempBits & ~17179869184L) == 0) + { + return; + } + tempBits &= ~17179869184L; + } + } + + + private struct HeaderReferences + { + public StringValues _CacheControl; + public StringValues _Connection; + public StringValues _Date; + public StringValues _KeepAlive; + public StringValues _Pragma; + public StringValues _Trailer; + public StringValues _TransferEncoding; + public StringValues _Upgrade; + public StringValues _Via; + public StringValues _Warning; + public StringValues _Allow; + public StringValues _ContentType; + public StringValues _ContentEncoding; + public StringValues _ContentLanguage; + public StringValues _ContentLocation; + public StringValues _ContentMD5; + public StringValues _ContentRange; + public StringValues _Expires; + public StringValues _LastModified; + public StringValues _AcceptRanges; + public StringValues _Age; + public StringValues _ETag; + public StringValues _Location; + public StringValues _ProxyAuthenticate; + public StringValues _RetryAfter; + public StringValues _Server; + public StringValues _SetCookie; + public StringValues _Vary; + public StringValues _WWWAuthenticate; + public StringValues _AccessControlAllowCredentials; + public StringValues _AccessControlAllowHeaders; + public StringValues _AccessControlAllowMethods; + public StringValues _AccessControlAllowOrigin; + public StringValues _AccessControlExposeHeaders; + public StringValues _AccessControlMaxAge; + + public byte[] _rawConnection; + public byte[] _rawDate; + public byte[] _rawTransferEncoding; + public byte[] _rawServer; + } + + public partial struct Enumerator + { + public bool MoveNext() + { + switch (_state) + { + + case 0: + goto state0; + + case 1: + goto state1; + + case 2: + goto state2; + + case 3: + goto state3; + + case 4: + goto state4; + + case 5: + goto state5; + + case 6: + goto state6; + + case 7: + goto state7; + + case 8: + goto state8; + + case 9: + goto state9; + + case 10: + goto state10; + + case 11: + goto state11; + + case 12: + goto state12; + + case 13: + goto state13; + + case 14: + goto state14; + + case 15: + goto state15; + + case 16: + goto state16; + + case 17: + goto state17; + + case 18: + goto state18; + + case 19: + goto state19; + + case 20: + goto state20; + + case 21: + goto state21; + + case 22: + goto state22; + + case 23: + goto state23; + + case 24: + goto state24; + + case 25: + goto state25; + + case 26: + goto state26; + + case 27: + goto state27; + + case 28: + goto state28; + + case 29: + goto state29; + + case 30: + goto state30; + + case 31: + goto state31; + + case 32: + goto state32; + + case 33: + goto state33; + + case 34: + goto state34; + + case 36: + goto state36; + default: + goto state_default; + } + + state0: + if ((_bits & 1L) != 0) + { + _current = new KeyValuePair("Cache-Control", _collection._headers._CacheControl); + _state = 1; + return true; + } + + state1: + if ((_bits & 2L) != 0) + { + _current = new KeyValuePair("Connection", _collection._headers._Connection); + _state = 2; + return true; + } + + state2: + if ((_bits & 4L) != 0) + { + _current = new KeyValuePair("Date", _collection._headers._Date); + _state = 3; + return true; + } + + state3: + if ((_bits & 8L) != 0) + { + _current = new KeyValuePair("Keep-Alive", _collection._headers._KeepAlive); + _state = 4; + return true; + } + + state4: + if ((_bits & 16L) != 0) + { + _current = new KeyValuePair("Pragma", _collection._headers._Pragma); + _state = 5; + return true; + } + + state5: + if ((_bits & 32L) != 0) + { + _current = new KeyValuePair("Trailer", _collection._headers._Trailer); + _state = 6; + return true; + } + + state6: + if ((_bits & 64L) != 0) + { + _current = new KeyValuePair("Transfer-Encoding", _collection._headers._TransferEncoding); + _state = 7; + return true; + } + + state7: + if ((_bits & 128L) != 0) + { + _current = new KeyValuePair("Upgrade", _collection._headers._Upgrade); + _state = 8; + return true; + } + + state8: + if ((_bits & 256L) != 0) + { + _current = new KeyValuePair("Via", _collection._headers._Via); + _state = 9; + return true; + } + + state9: + if ((_bits & 512L) != 0) + { + _current = new KeyValuePair("Warning", _collection._headers._Warning); + _state = 10; + return true; + } + + state10: + if ((_bits & 1024L) != 0) + { + _current = new KeyValuePair("Allow", _collection._headers._Allow); + _state = 11; + return true; + } + + state11: + if ((_bits & 2048L) != 0) + { + _current = new KeyValuePair("Content-Type", _collection._headers._ContentType); + _state = 12; + return true; + } + + state12: + if ((_bits & 4096L) != 0) + { + _current = new KeyValuePair("Content-Encoding", _collection._headers._ContentEncoding); + _state = 13; + return true; + } + + state13: + if ((_bits & 8192L) != 0) + { + _current = new KeyValuePair("Content-Language", _collection._headers._ContentLanguage); + _state = 14; + return true; + } + + state14: + if ((_bits & 16384L) != 0) + { + _current = new KeyValuePair("Content-Location", _collection._headers._ContentLocation); + _state = 15; + return true; + } + + state15: + if ((_bits & 32768L) != 0) + { + _current = new KeyValuePair("Content-MD5", _collection._headers._ContentMD5); + _state = 16; + return true; + } + + state16: + if ((_bits & 65536L) != 0) + { + _current = new KeyValuePair("Content-Range", _collection._headers._ContentRange); + _state = 17; + return true; + } + + state17: + if ((_bits & 131072L) != 0) + { + _current = new KeyValuePair("Expires", _collection._headers._Expires); + _state = 18; + return true; + } + + state18: + if ((_bits & 262144L) != 0) + { + _current = new KeyValuePair("Last-Modified", _collection._headers._LastModified); + _state = 19; + return true; + } + + state19: + if ((_bits & 524288L) != 0) + { + _current = new KeyValuePair("Accept-Ranges", _collection._headers._AcceptRanges); + _state = 20; + return true; + } + + state20: + if ((_bits & 1048576L) != 0) + { + _current = new KeyValuePair("Age", _collection._headers._Age); + _state = 21; + return true; + } + + state21: + if ((_bits & 2097152L) != 0) + { + _current = new KeyValuePair("ETag", _collection._headers._ETag); + _state = 22; + return true; + } + + state22: + if ((_bits & 4194304L) != 0) + { + _current = new KeyValuePair("Location", _collection._headers._Location); + _state = 23; + return true; + } + + state23: + if ((_bits & 8388608L) != 0) + { + _current = new KeyValuePair("Proxy-Authenticate", _collection._headers._ProxyAuthenticate); + _state = 24; + return true; + } + + state24: + if ((_bits & 16777216L) != 0) + { + _current = new KeyValuePair("Retry-After", _collection._headers._RetryAfter); + _state = 25; + return true; + } + + state25: + if ((_bits & 33554432L) != 0) + { + _current = new KeyValuePair("Server", _collection._headers._Server); + _state = 26; + return true; + } + + state26: + if ((_bits & 67108864L) != 0) + { + _current = new KeyValuePair("Set-Cookie", _collection._headers._SetCookie); + _state = 27; + return true; + } + + state27: + if ((_bits & 134217728L) != 0) + { + _current = new KeyValuePair("Vary", _collection._headers._Vary); + _state = 28; + return true; + } + + state28: + if ((_bits & 268435456L) != 0) + { + _current = new KeyValuePair("WWW-Authenticate", _collection._headers._WWWAuthenticate); + _state = 29; + return true; + } + + state29: + if ((_bits & 536870912L) != 0) + { + _current = new KeyValuePair("Access-Control-Allow-Credentials", _collection._headers._AccessControlAllowCredentials); + _state = 30; + return true; + } + + state30: + if ((_bits & 1073741824L) != 0) + { + _current = new KeyValuePair("Access-Control-Allow-Headers", _collection._headers._AccessControlAllowHeaders); + _state = 31; + return true; + } + + state31: + if ((_bits & 2147483648L) != 0) + { + _current = new KeyValuePair("Access-Control-Allow-Methods", _collection._headers._AccessControlAllowMethods); + _state = 32; + return true; + } + + state32: + if ((_bits & 4294967296L) != 0) + { + _current = new KeyValuePair("Access-Control-Allow-Origin", _collection._headers._AccessControlAllowOrigin); + _state = 33; + return true; + } + + state33: + if ((_bits & 8589934592L) != 0) + { + _current = new KeyValuePair("Access-Control-Expose-Headers", _collection._headers._AccessControlExposeHeaders); + _state = 34; + return true; + } + + state34: + if ((_bits & 17179869184L) != 0) + { + _current = new KeyValuePair("Access-Control-Max-Age", _collection._headers._AccessControlMaxAge); + _state = 35; + return true; + } + + state36: + if (_collection._contentLength.HasValue) + { + _current = new KeyValuePair("Content-Length", HeaderUtilities.FormatNonNegativeInt64(_collection._contentLength.Value)); + _state = 37; + return true; + } + state_default: + if (!_hasUnknown || !_unknownEnumerator.MoveNext()) + { + _current = default(KeyValuePair); + return false; + } + _current = _unknownEnumerator.Current; + return true; + } + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.cs new file mode 100644 index 0000000000..444ac62774 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.cs @@ -0,0 +1,444 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public abstract class HttpHeaders : IHeaderDictionary + { + protected long? _contentLength; + protected bool _isReadOnly; + protected Dictionary MaybeUnknown; + protected Dictionary Unknown => MaybeUnknown ?? (MaybeUnknown = new Dictionary(StringComparer.OrdinalIgnoreCase)); + + public long? ContentLength + { + get { return _contentLength; } + set + { + if (value.HasValue && value.Value < 0) + { + ThrowInvalidContentLengthException(value.Value); + } + _contentLength = value; + } + } + + StringValues IHeaderDictionary.this[string key] + { + get + { + StringValues value; + TryGetValueFast(key, out value); + return value; + } + set + { + if (_isReadOnly) + { + ThrowHeadersReadOnlyException(); + } + if (value.Count == 0) + { + RemoveFast(key); + } + else + { + SetValueFast(key, value); + } + } + } + + StringValues IDictionary.this[string key] + { + get + { + // Unlike the IHeaderDictionary version, this getter will throw a KeyNotFoundException. + StringValues value; + if (!TryGetValueFast(key, out value)) + { + ThrowKeyNotFoundException(); + } + return value; + } + set + { + ((IHeaderDictionary)this)[key] = value; + } + } + + protected void ThrowHeadersReadOnlyException() + { + throw new InvalidOperationException(CoreStrings.HeadersAreReadOnly); + } + + protected void ThrowArgumentException() + { + throw new ArgumentException(); + } + + protected void ThrowKeyNotFoundException() + { + throw new KeyNotFoundException(); + } + + protected void ThrowDuplicateKeyException() + { + throw new ArgumentException(CoreStrings.KeyAlreadyExists); + } + + int ICollection>.Count => GetCountFast(); + + bool ICollection>.IsReadOnly => _isReadOnly; + + ICollection IDictionary.Keys => ((IDictionary)this).Select(pair => pair.Key).ToList(); + + ICollection IDictionary.Values => ((IDictionary)this).Select(pair => pair.Value).ToList(); + + public void SetReadOnly() + { + _isReadOnly = true; + } + + public void Reset() + { + _isReadOnly = false; + ClearFast(); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + protected static StringValues AppendValue(in StringValues existing, string append) + { + return StringValues.Concat(existing, append); + } + + protected static int BitCount(long value) + { + // see https://github.com/dotnet/corefx/blob/5965fd3756bc9dd9c89a27621eb10c6931126de2/src/System.Reflection.Metadata/src/System/Reflection/Internal/Utilities/BitArithmetic.cs + + const ulong Mask01010101 = 0x5555555555555555UL; + const ulong Mask00110011 = 0x3333333333333333UL; + const ulong Mask00001111 = 0x0F0F0F0F0F0F0F0FUL; + const ulong Mask00000001 = 0x0101010101010101UL; + + var v = (ulong)value; + + v = v - ((v >> 1) & Mask01010101); + v = (v & Mask00110011) + ((v >> 2) & Mask00110011); + return (int)(unchecked(((v + (v >> 4)) & Mask00001111) * Mask00000001) >> 56); + } + + protected virtual int GetCountFast() + { throw new NotImplementedException(); } + + protected virtual bool TryGetValueFast(string key, out StringValues value) + { throw new NotImplementedException(); } + + protected virtual void SetValueFast(string key, in StringValues value) + { throw new NotImplementedException(); } + + protected virtual bool AddValueFast(string key, in StringValues value) + { throw new NotImplementedException(); } + + protected virtual bool RemoveFast(string key) + { throw new NotImplementedException(); } + + protected virtual void ClearFast() + { throw new NotImplementedException(); } + + protected virtual bool CopyToFast(KeyValuePair[] array, int arrayIndex) + { throw new NotImplementedException(); } + + protected virtual IEnumerator> GetEnumeratorFast() + { throw new NotImplementedException(); } + + void ICollection>.Add(KeyValuePair item) + { + ((IDictionary)this).Add(item.Key, item.Value); + } + + void IDictionary.Add(string key, StringValues value) + { + if (_isReadOnly) + { + ThrowHeadersReadOnlyException(); + } + + if (value.Count > 0 && !AddValueFast(key, value)) + { + ThrowDuplicateKeyException(); + } + } + + void ICollection>.Clear() + { + if (_isReadOnly) + { + ThrowHeadersReadOnlyException(); + } + ClearFast(); + } + + bool ICollection>.Contains(KeyValuePair item) + { + StringValues value; + return + TryGetValueFast(item.Key, out value) && + value.Equals(item.Value); + } + + bool IDictionary.ContainsKey(string key) + { + StringValues value; + return TryGetValueFast(key, out value); + } + + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) + { + if (!CopyToFast(array, arrayIndex)) + { + ThrowArgumentException(); + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumeratorFast(); + } + + IEnumerator> IEnumerable>.GetEnumerator() + { + return GetEnumeratorFast(); + } + + bool ICollection>.Remove(KeyValuePair item) + { + StringValues value; + return + TryGetValueFast(item.Key, out value) && + value.Equals(item.Value) && + RemoveFast(item.Key); + } + + bool IDictionary.Remove(string key) + { + if (_isReadOnly) + { + ThrowHeadersReadOnlyException(); + } + return RemoveFast(key); + } + + bool IDictionary.TryGetValue(string key, out StringValues value) + { + return TryGetValueFast(key, out value); + } + + public static void ValidateHeaderCharacters(in StringValues headerValues) + { + var count = headerValues.Count; + for (var i = 0; i < count; i++) + + { + ValidateHeaderCharacters(headerValues[i]); + } + } + + public static void ValidateHeaderCharacters(string headerCharacters) + { + if (headerCharacters != null) + { + foreach (var ch in headerCharacters) + { + if (ch < 0x20 || ch > 0x7E) + { + ThrowInvalidHeaderCharacter(ch); + } + } + } + } + + public static unsafe ConnectionOptions ParseConnection(in StringValues connection) + { + var connectionOptions = ConnectionOptions.None; + + var connectionCount = connection.Count; + for (var i = 0; i < connectionCount; i++) + { + var value = connection[i]; + fixed (char* ptr = value) + { + var ch = ptr; + var tokenEnd = ch; + var end = ch + value.Length; + + while (ch < end) + { + while (tokenEnd < end && *tokenEnd != ',') + { + tokenEnd++; + } + + while (ch < tokenEnd && *ch == ' ') + { + ch++; + } + + var tokenLength = tokenEnd - ch; + + if (tokenLength >= 9 && (*ch | 0x20) == 'k') + { + if ((*++ch | 0x20) == 'e' && + (*++ch | 0x20) == 'e' && + (*++ch | 0x20) == 'p' && + *++ch == '-' && + (*++ch | 0x20) == 'a' && + (*++ch | 0x20) == 'l' && + (*++ch | 0x20) == 'i' && + (*++ch | 0x20) == 'v' && + (*++ch | 0x20) == 'e') + { + ch++; + while (ch < tokenEnd && *ch == ' ') + { + ch++; + } + + if (ch == tokenEnd) + { + connectionOptions |= ConnectionOptions.KeepAlive; + } + } + } + else if (tokenLength >= 7 && (*ch | 0x20) == 'u') + { + if ((*++ch | 0x20) == 'p' && + (*++ch | 0x20) == 'g' && + (*++ch | 0x20) == 'r' && + (*++ch | 0x20) == 'a' && + (*++ch | 0x20) == 'd' && + (*++ch | 0x20) == 'e') + { + ch++; + while (ch < tokenEnd && *ch == ' ') + { + ch++; + } + + if (ch == tokenEnd) + { + connectionOptions |= ConnectionOptions.Upgrade; + } + } + } + else if (tokenLength >= 5 && (*ch | 0x20) == 'c') + { + if ((*++ch | 0x20) == 'l' && + (*++ch | 0x20) == 'o' && + (*++ch | 0x20) == 's' && + (*++ch | 0x20) == 'e') + { + ch++; + while (ch < tokenEnd && *ch == ' ') + { + ch++; + } + + if (ch == tokenEnd) + { + connectionOptions |= ConnectionOptions.Close; + } + } + } + + tokenEnd++; + ch = tokenEnd; + } + } + } + + return connectionOptions; + } + + public static unsafe TransferCoding GetFinalTransferCoding(in StringValues transferEncoding) + { + var transferEncodingOptions = TransferCoding.None; + + var transferEncodingCount = transferEncoding.Count; + for (var i = 0; i < transferEncodingCount; i++) + { + var value = transferEncoding[i]; + fixed (char* ptr = value) + { + var ch = ptr; + var tokenEnd = ch; + var end = ch + value.Length; + + while (ch < end) + { + while (tokenEnd < end && *tokenEnd != ',') + { + tokenEnd++; + } + + while (ch < tokenEnd && *ch == ' ') + { + ch++; + } + + var tokenLength = tokenEnd - ch; + + if (tokenLength >= 7 && (*ch | 0x20) == 'c') + { + if ((*++ch | 0x20) == 'h' && + (*++ch | 0x20) == 'u' && + (*++ch | 0x20) == 'n' && + (*++ch | 0x20) == 'k' && + (*++ch | 0x20) == 'e' && + (*++ch | 0x20) == 'd') + { + ch++; + while (ch < tokenEnd && *ch == ' ') + { + ch++; + } + + if (ch == tokenEnd) + { + transferEncodingOptions = TransferCoding.Chunked; + } + } + } + + if (tokenLength > 0 && ch != tokenEnd) + { + transferEncodingOptions = TransferCoding.Other; + } + + tokenEnd++; + ch = tokenEnd; + } + } + } + + return transferEncodingOptions; + } + + private static void ThrowInvalidContentLengthException(long value) + { + throw new ArgumentOutOfRangeException(CoreStrings.FormatInvalidContentLength_InvalidNumber(value)); + } + + private static void ThrowInvalidHeaderCharacter(char ch) + { + throw new InvalidOperationException(CoreStrings.FormatInvalidAsciiOrControlChar(string.Format("0x{0:X4}", (ushort)ch))); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpMethod.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpMethod.cs new file mode 100644 index 0000000000..3e6ff0667e --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpMethod.cs @@ -0,0 +1,22 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public enum HttpMethod: byte + { + Get, + Put, + Delete, + Post, + Head, + Trace, + Patch, + Connect, + Options, + + Custom, + + None = byte.MaxValue, + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpParser.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpParser.cs new file mode 100644 index 0000000000..b3d8d90d9c --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpParser.cs @@ -0,0 +1,509 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public class HttpParser : IHttpParser where TRequestHandler : IHttpHeadersHandler, IHttpRequestLineHandler + { + private bool _showErrorDetails; + + public HttpParser() : this(showErrorDetails: true) + { + } + + public HttpParser(bool showErrorDetails) + { + _showErrorDetails = showErrorDetails; + } + + // byte types don't have a data type annotation so we pre-cast them; to avoid in-place casts + private const byte ByteCR = (byte)'\r'; + private const byte ByteLF = (byte)'\n'; + private const byte ByteColon = (byte)':'; + private const byte ByteSpace = (byte)' '; + private const byte ByteTab = (byte)'\t'; + private const byte ByteQuestionMark = (byte)'?'; + private const byte BytePercentage = (byte)'%'; + + public unsafe bool ParseRequestLine(TRequestHandler handler, in ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + consumed = buffer.Start; + examined = buffer.End; + + // Prepare the first span + var span = buffer.First.Span; + var lineIndex = span.IndexOf(ByteLF); + if (lineIndex >= 0) + { + consumed = buffer.GetPosition(lineIndex + 1, consumed); + span = span.Slice(0, lineIndex + 1); + } + else if (buffer.IsSingleSegment) + { + // No request line end + return false; + } + else if (TryGetNewLine(buffer, out var found)) + { + span = buffer.Slice(consumed, found).ToSpan(); + consumed = found; + } + else + { + // No request line end + return false; + } + + // Fix and parse the span + fixed (byte* data = &MemoryMarshal.GetReference(span)) + { + ParseRequestLine(handler, data, span.Length); + } + + examined = consumed; + return true; + } + + private unsafe void ParseRequestLine(TRequestHandler handler, byte* data, int length) + { + int offset; + // Get Method and set the offset + var method = HttpUtilities.GetKnownMethod(data, length, out offset); + + Span customMethod = method == HttpMethod.Custom ? + GetUnknownMethod(data, length, out offset) : + default; + + // Skip space + offset++; + + byte ch = 0; + // Target = Path and Query + var pathEncoded = false; + var pathStart = -1; + for (; offset < length; offset++) + { + ch = data[offset]; + if (ch == ByteSpace) + { + if (pathStart == -1) + { + // Empty path is illegal + RejectRequestLine(data, length); + } + + break; + } + else if (ch == ByteQuestionMark) + { + if (pathStart == -1) + { + // Empty path is illegal + RejectRequestLine(data, length); + } + + break; + } + else if (ch == BytePercentage) + { + if (pathStart == -1) + { + // Path starting with % is illegal + RejectRequestLine(data, length); + } + + pathEncoded = true; + } + else if (pathStart == -1) + { + pathStart = offset; + } + } + + if (pathStart == -1) + { + // Start of path not found + RejectRequestLine(data, length); + } + + var pathBuffer = new Span(data + pathStart, offset - pathStart); + + // Query string + var queryStart = offset; + if (ch == ByteQuestionMark) + { + // We have a query string + for (; offset < length; offset++) + { + ch = data[offset]; + if (ch == ByteSpace) + { + break; + } + } + } + + // End of query string not found + if (offset == length) + { + RejectRequestLine(data, length); + } + + var targetBuffer = new Span(data + pathStart, offset - pathStart); + var query = new Span(data + queryStart, offset - queryStart); + + // Consume space + offset++; + + // Version + var httpVersion = HttpUtilities.GetKnownVersion(data + offset, length - offset); + if (httpVersion == HttpVersion.Unknown) + { + if (data[offset] == ByteCR || data[length - 2] != ByteCR) + { + // If missing delimiter or CR before LF, reject and log entire line + RejectRequestLine(data, length); + } + else + { + // else inform HTTP version is unsupported. + RejectUnknownVersion(data + offset, length - offset - 2); + } + } + + // After version's 8 bytes and CR, expect LF + if (data[offset + 8 + 1] != ByteLF) + { + RejectRequestLine(data, length); + } + + handler.OnStartLine(method, httpVersion, targetBuffer, pathBuffer, query, customMethod, pathEncoded); + } + + public unsafe bool ParseHeaders(TRequestHandler handler, in ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined, out int consumedBytes) + { + consumed = buffer.Start; + examined = buffer.End; + consumedBytes = 0; + + var bufferEnd = buffer.End; + + var reader = new BufferReader(buffer); + var start = default(BufferReader); + var done = false; + + try + { + while (!reader.End) + { + var span = reader.CurrentSegment; + var remaining = span.Length - reader.CurrentSegmentIndex; + + fixed (byte* pBuffer = &MemoryMarshal.GetReference(span)) + { + while (remaining > 0) + { + var index = reader.CurrentSegmentIndex; + int ch1; + int ch2; + var readAhead = false; + + // Fast path, we're still looking at the same span + if (remaining >= 2) + { + ch1 = pBuffer[index]; + ch2 = pBuffer[index + 1]; + } + else + { + // Store the reader before we look ahead 2 bytes (probably straddling + // spans) + start = reader; + + // Possibly split across spans + ch1 = reader.Read(); + ch2 = reader.Read(); + + readAhead = true; + } + + if (ch1 == ByteCR) + { + // Check for final CRLF. + if (ch2 == -1) + { + // Reset the reader so we don't consume anything + reader = start; + return false; + } + else if (ch2 == ByteLF) + { + // If we got 2 bytes from the span directly so skip ahead 2 so that + // the reader's state matches what we expect + if (!readAhead) + { + reader.Advance(2); + } + + done = true; + return true; + } + + // Headers don't end in CRLF line. + BadHttpRequestException.Throw(RequestRejectionReason.InvalidRequestHeadersNoCRLF); + } + + // We moved the reader so look ahead 2 bytes so reset both the reader + // and the index + if (readAhead) + { + reader = start; + index = reader.CurrentSegmentIndex; + } + + var endIndex = new Span(pBuffer + index, remaining).IndexOf(ByteLF); + var length = 0; + + if (endIndex != -1) + { + length = endIndex + 1; + var pHeader = pBuffer + index; + + TakeSingleHeader(pHeader, length, handler); + } + else + { + var current = reader.Position; + var currentSlice = buffer.Slice(current, bufferEnd); + + var lineEndPosition = currentSlice.PositionOf(ByteLF); + // Split buffers + if (lineEndPosition == null) + { + // Not there + return false; + } + + var lineEnd = lineEndPosition.Value; + + // Make sure LF is included in lineEnd + lineEnd = buffer.GetPosition(1, lineEnd); + var headerSpan = buffer.Slice(current, lineEnd).ToSpan(); + length = headerSpan.Length; + + fixed (byte* pHeader = &MemoryMarshal.GetReference(headerSpan)) + { + TakeSingleHeader(pHeader, length, handler); + } + + // We're going to the next span after this since we know we crossed spans here + // so mark the remaining as equal to the headerSpan so that we end up at 0 + // on the next iteration + remaining = length; + } + + // Skip the reader forward past the header line + reader.Advance(length); + remaining -= length; + } + } + } + + return false; + } + finally + { + consumed = reader.Position; + consumedBytes = reader.ConsumedBytes; + + if (done) + { + examined = consumed; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private unsafe int FindEndOfName(byte* headerLine, int length) + { + var index = 0; + var sawWhitespace = false; + for (; index < length; index++) + { + var ch = headerLine[index]; + if (ch == ByteColon) + { + break; + } + if (ch == ByteTab || ch == ByteSpace || ch == ByteCR) + { + sawWhitespace = true; + } + } + + if (index == length || sawWhitespace) + { + RejectRequestHeader(headerLine, length); + } + + return index; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private unsafe void TakeSingleHeader(byte* headerLine, int length, TRequestHandler handler) + { + // Skip CR, LF from end position + var valueEnd = length - 3; + var nameEnd = FindEndOfName(headerLine, length); + + if (headerLine[valueEnd + 2] != ByteLF) + { + RejectRequestHeader(headerLine, length); + } + if (headerLine[valueEnd + 1] != ByteCR) + { + RejectRequestHeader(headerLine, length); + } + + // Skip colon from value start + var valueStart = nameEnd + 1; + // Ignore start whitespace + for (; valueStart < valueEnd; valueStart++) + { + var ch = headerLine[valueStart]; + if (ch != ByteTab && ch != ByteSpace && ch != ByteCR) + { + break; + } + else if (ch == ByteCR) + { + RejectRequestHeader(headerLine, length); + } + } + + // Check for CR in value + var valueBuffer = new Span(headerLine + valueStart, valueEnd - valueStart + 1); + if (valueBuffer.IndexOf(ByteCR) >= 0) + { + RejectRequestHeader(headerLine, length); + } + + // Ignore end whitespace + var lengthChanged = false; + for (; valueEnd >= valueStart; valueEnd--) + { + var ch = headerLine[valueEnd]; + if (ch != ByteTab && ch != ByteSpace) + { + break; + } + + lengthChanged = true; + } + + if (lengthChanged) + { + // Length changed + valueBuffer = new Span(headerLine + valueStart, valueEnd - valueStart + 1); + } + + var nameBuffer = new Span(headerLine, nameEnd); + + handler.OnHeader(nameBuffer, valueBuffer); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static bool TryGetNewLine(in ReadOnlySequence buffer, out SequencePosition found) + { + var byteLfPosition = buffer.PositionOf(ByteLF); + if (byteLfPosition != null) + { + // Move 1 byte past the \n + found = buffer.GetPosition(1, byteLfPosition.Value); + return true; + } + + found = default; + return false; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private unsafe Span GetUnknownMethod(byte* data, int length, out int methodLength) + { + methodLength = 0; + for (var i = 0; i < length; i++) + { + var ch = data[i]; + + if (ch == ByteSpace) + { + if (i == 0) + { + RejectRequestLine(data, length); + } + + methodLength = i; + break; + } + else if (!IsValidTokenChar((char)ch)) + { + RejectRequestLine(data, length); + } + } + + return new Span(data, methodLength); + } + + private static bool IsValidTokenChar(char c) + { + // Determines if a character is valid as a 'token' as defined in the + // HTTP spec: https://tools.ietf.org/html/rfc7230#section-3.2.6 + return + (c >= '0' && c <= '9') || + (c >= 'A' && c <= 'Z') || + (c >= 'a' && c <= 'z') || + c == '!' || + c == '#' || + c == '$' || + c == '%' || + c == '&' || + c == '\'' || + c == '*' || + c == '+' || + c == '-' || + c == '.' || + c == '^' || + c == '_' || + c == '`' || + c == '|' || + c == '~'; + } + + [StackTraceHidden] + private unsafe void RejectRequestLine(byte* requestLine, int length) + => throw GetInvalidRequestException(RequestRejectionReason.InvalidRequestLine, requestLine, length); + + [StackTraceHidden] + private unsafe void RejectRequestHeader(byte* headerLine, int length) + => throw GetInvalidRequestException(RequestRejectionReason.InvalidRequestHeader, headerLine, length); + + [StackTraceHidden] + private unsafe void RejectUnknownVersion(byte* version, int length) + => throw GetInvalidRequestException(RequestRejectionReason.UnrecognizedHTTPVersion, version, length); + + [MethodImpl(MethodImplOptions.NoInlining)] + private unsafe BadHttpRequestException GetInvalidRequestException(RequestRejectionReason reason, byte* detail, int length) + => BadHttpRequestException.GetException( + reason, + _showErrorDetails + ? new Span(detail, length).GetAsciiStringEscaped(Constants.MaxExceptionDetailSize) + : string.Empty); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs new file mode 100644 index 0000000000..118a4ca97e --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs @@ -0,0 +1,290 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public partial class HttpProtocol : IFeatureCollection, + IHttpRequestFeature, + IHttpResponseFeature, + IHttpConnectionFeature, + IHttpRequestLifetimeFeature, + IHttpRequestIdentifierFeature, + IHttpBodyControlFeature, + IHttpMaxRequestBodySizeFeature, + IHttpMinRequestBodyDataRateFeature, + IHttpMinResponseDataRateFeature + { + // NOTE: When feature interfaces are added to or removed from this HttpProtocol class implementation, + // then the list of `implementedFeatures` in the generated code project MUST also be updated. + // See also: tools/Microsoft.AspNetCore.Server.Kestrel.GeneratedCode/HttpProtocolFeatureCollection.cs + + private int _featureRevision; + + private List> MaybeExtra; + + public void ResetFeatureCollection() + { + FastReset(); + MaybeExtra?.Clear(); + _featureRevision++; + } + + private object ExtraFeatureGet(Type key) + { + if (MaybeExtra == null) + { + return null; + } + for (var i = 0; i < MaybeExtra.Count; i++) + { + var kv = MaybeExtra[i]; + if (kv.Key == key) + { + return kv.Value; + } + } + return null; + } + + private void ExtraFeatureSet(Type key, object value) + { + if (MaybeExtra == null) + { + MaybeExtra = new List>(2); + } + + for (var i = 0; i < MaybeExtra.Count; i++) + { + if (MaybeExtra[i].Key == key) + { + MaybeExtra[i] = new KeyValuePair(key, value); + return; + } + } + MaybeExtra.Add(new KeyValuePair(key, value)); + } + + string IHttpRequestFeature.Protocol + { + get => HttpVersion; + set => HttpVersion = value; + } + + string IHttpRequestFeature.Scheme + { + get => Scheme ?? "http"; + set => Scheme = value; + } + + string IHttpRequestFeature.Method + { + get + { + if (_methodText != null) + { + return _methodText; + } + + _methodText = HttpUtilities.MethodToString(Method) ?? string.Empty; + return _methodText; + } + set + { + _methodText = value; + } + } + + string IHttpRequestFeature.PathBase + { + get => PathBase ?? ""; + set => PathBase = value; + } + + string IHttpRequestFeature.Path + { + get => Path; + set => Path = value; + } + + string IHttpRequestFeature.QueryString + { + get => QueryString; + set => QueryString = value; + } + + string IHttpRequestFeature.RawTarget + { + get => RawTarget; + set => RawTarget = value; + } + + IHeaderDictionary IHttpRequestFeature.Headers + { + get => RequestHeaders; + set => RequestHeaders = value; + } + + Stream IHttpRequestFeature.Body + { + get => RequestBody; + set => RequestBody = value; + } + + int IHttpResponseFeature.StatusCode + { + get => StatusCode; + set => StatusCode = value; + } + + string IHttpResponseFeature.ReasonPhrase + { + get => ReasonPhrase; + set => ReasonPhrase = value; + } + + IHeaderDictionary IHttpResponseFeature.Headers + { + get => ResponseHeaders; + set => ResponseHeaders = value; + } + + Stream IHttpResponseFeature.Body + { + get => ResponseBody; + set => ResponseBody = value; + } + + CancellationToken IHttpRequestLifetimeFeature.RequestAborted + { + get => RequestAborted; + set => RequestAborted = value; + } + + bool IHttpResponseFeature.HasStarted => HasResponseStarted; + + bool IFeatureCollection.IsReadOnly => false; + + int IFeatureCollection.Revision => _featureRevision; + + IPAddress IHttpConnectionFeature.RemoteIpAddress + { + get => RemoteIpAddress; + set => RemoteIpAddress = value; + } + + IPAddress IHttpConnectionFeature.LocalIpAddress + { + get => LocalIpAddress; + set => LocalIpAddress = value; + } + + int IHttpConnectionFeature.RemotePort + { + get => RemotePort; + set => RemotePort = value; + } + + int IHttpConnectionFeature.LocalPort + { + get => LocalPort; + set => LocalPort = value; + } + + string IHttpConnectionFeature.ConnectionId + { + get => ConnectionIdFeature; + set => ConnectionIdFeature = value; + } + + string IHttpRequestIdentifierFeature.TraceIdentifier + { + get => TraceIdentifier; + set => TraceIdentifier = value; + } + + bool IHttpBodyControlFeature.AllowSynchronousIO + { + get => AllowSynchronousIO; + set => AllowSynchronousIO = value; + } + + bool IHttpMaxRequestBodySizeFeature.IsReadOnly => HasStartedConsumingRequestBody || IsUpgraded; + + long? IHttpMaxRequestBodySizeFeature.MaxRequestBodySize + { + get => MaxRequestBodySize; + set + { + if (HasStartedConsumingRequestBody) + { + throw new InvalidOperationException(CoreStrings.MaxRequestBodySizeCannotBeModifiedAfterRead); + } + if (IsUpgraded) + { + throw new InvalidOperationException(CoreStrings.MaxRequestBodySizeCannotBeModifiedForUpgradedRequests); + } + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.NonNegativeNumberOrNullRequired); + } + + MaxRequestBodySize = value; + } + } + + MinDataRate IHttpMinRequestBodyDataRateFeature.MinDataRate + { + get => MinRequestBodyDataRate; + set => MinRequestBodyDataRate = value; + } + + MinDataRate IHttpMinResponseDataRateFeature.MinDataRate + { + get => MinResponseDataRate; + set => MinResponseDataRate = value; + } + + protected void ResetIHttpUpgradeFeature() + { + _currentIHttpUpgradeFeature = this; + } + + protected void ResetIHttp2StreamIdFeature() + { + _currentIHttp2StreamIdFeature = this; + } + + void IHttpResponseFeature.OnStarting(Func callback, object state) + { + OnStarting(callback, state); + } + + void IHttpResponseFeature.OnCompleted(Func callback, object state) + { + OnCompleted(callback, state); + } + + IEnumerator> IEnumerable>.GetEnumerator() => FastEnumerable().GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => FastEnumerable().GetEnumerator(); + + void IHttpRequestLifetimeFeature.Abort() + { + Log.ApplicationAbortedConnection(ConnectionId, TraceIdentifier); + Abort(new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication)); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs new file mode 100644 index 0000000000..3bac0a0778 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs @@ -0,0 +1,566 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Http.Features.Authentication; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public partial class HttpProtocol + { + private static readonly Type IHttpRequestFeatureType = typeof(IHttpRequestFeature); + private static readonly Type IHttpResponseFeatureType = typeof(IHttpResponseFeature); + private static readonly Type IHttpRequestIdentifierFeatureType = typeof(IHttpRequestIdentifierFeature); + private static readonly Type IServiceProvidersFeatureType = typeof(IServiceProvidersFeature); + private static readonly Type IHttpRequestLifetimeFeatureType = typeof(IHttpRequestLifetimeFeature); + private static readonly Type IHttpConnectionFeatureType = typeof(IHttpConnectionFeature); + private static readonly Type IHttpAuthenticationFeatureType = typeof(IHttpAuthenticationFeature); + private static readonly Type IQueryFeatureType = typeof(IQueryFeature); + private static readonly Type IFormFeatureType = typeof(IFormFeature); + private static readonly Type IHttpUpgradeFeatureType = typeof(IHttpUpgradeFeature); + private static readonly Type IHttp2StreamIdFeatureType = typeof(IHttp2StreamIdFeature); + private static readonly Type IResponseCookiesFeatureType = typeof(IResponseCookiesFeature); + private static readonly Type IItemsFeatureType = typeof(IItemsFeature); + private static readonly Type ITlsConnectionFeatureType = typeof(ITlsConnectionFeature); + private static readonly Type IHttpWebSocketFeatureType = typeof(IHttpWebSocketFeature); + private static readonly Type ISessionFeatureType = typeof(ISessionFeature); + private static readonly Type IHttpMaxRequestBodySizeFeatureType = typeof(IHttpMaxRequestBodySizeFeature); + private static readonly Type IHttpMinRequestBodyDataRateFeatureType = typeof(IHttpMinRequestBodyDataRateFeature); + private static readonly Type IHttpMinResponseDataRateFeatureType = typeof(IHttpMinResponseDataRateFeature); + private static readonly Type IHttpBodyControlFeatureType = typeof(IHttpBodyControlFeature); + private static readonly Type IHttpSendFileFeatureType = typeof(IHttpSendFileFeature); + + private object _currentIHttpRequestFeature; + private object _currentIHttpResponseFeature; + private object _currentIHttpRequestIdentifierFeature; + private object _currentIServiceProvidersFeature; + private object _currentIHttpRequestLifetimeFeature; + private object _currentIHttpConnectionFeature; + private object _currentIHttpAuthenticationFeature; + private object _currentIQueryFeature; + private object _currentIFormFeature; + private object _currentIHttpUpgradeFeature; + private object _currentIHttp2StreamIdFeature; + private object _currentIResponseCookiesFeature; + private object _currentIItemsFeature; + private object _currentITlsConnectionFeature; + private object _currentIHttpWebSocketFeature; + private object _currentISessionFeature; + private object _currentIHttpMaxRequestBodySizeFeature; + private object _currentIHttpMinRequestBodyDataRateFeature; + private object _currentIHttpMinResponseDataRateFeature; + private object _currentIHttpBodyControlFeature; + private object _currentIHttpSendFileFeature; + + private void FastReset() + { + _currentIHttpRequestFeature = this; + _currentIHttpResponseFeature = this; + _currentIHttpRequestIdentifierFeature = this; + _currentIHttpRequestLifetimeFeature = this; + _currentIHttpConnectionFeature = this; + _currentIHttpMaxRequestBodySizeFeature = this; + _currentIHttpMinRequestBodyDataRateFeature = this; + _currentIHttpMinResponseDataRateFeature = this; + _currentIHttpBodyControlFeature = this; + + _currentIServiceProvidersFeature = null; + _currentIHttpAuthenticationFeature = null; + _currentIQueryFeature = null; + _currentIFormFeature = null; + _currentIHttpUpgradeFeature = null; + _currentIHttp2StreamIdFeature = null; + _currentIResponseCookiesFeature = null; + _currentIItemsFeature = null; + _currentITlsConnectionFeature = null; + _currentIHttpWebSocketFeature = null; + _currentISessionFeature = null; + _currentIHttpSendFileFeature = null; + } + + object IFeatureCollection.this[Type key] + { + get + { + object feature = null; + if (key == IHttpRequestFeatureType) + { + feature = _currentIHttpRequestFeature; + } + else if (key == IHttpResponseFeatureType) + { + feature = _currentIHttpResponseFeature; + } + else if (key == IHttpRequestIdentifierFeatureType) + { + feature = _currentIHttpRequestIdentifierFeature; + } + else if (key == IServiceProvidersFeatureType) + { + feature = _currentIServiceProvidersFeature; + } + else if (key == IHttpRequestLifetimeFeatureType) + { + feature = _currentIHttpRequestLifetimeFeature; + } + else if (key == IHttpConnectionFeatureType) + { + feature = _currentIHttpConnectionFeature; + } + else if (key == IHttpAuthenticationFeatureType) + { + feature = _currentIHttpAuthenticationFeature; + } + else if (key == IQueryFeatureType) + { + feature = _currentIQueryFeature; + } + else if (key == IFormFeatureType) + { + feature = _currentIFormFeature; + } + else if (key == IHttpUpgradeFeatureType) + { + feature = _currentIHttpUpgradeFeature; + } + else if (key == IHttp2StreamIdFeatureType) + { + feature = _currentIHttp2StreamIdFeature; + } + else if (key == IResponseCookiesFeatureType) + { + feature = _currentIResponseCookiesFeature; + } + else if (key == IItemsFeatureType) + { + feature = _currentIItemsFeature; + } + else if (key == ITlsConnectionFeatureType) + { + feature = _currentITlsConnectionFeature; + } + else if (key == IHttpWebSocketFeatureType) + { + feature = _currentIHttpWebSocketFeature; + } + else if (key == ISessionFeatureType) + { + feature = _currentISessionFeature; + } + else if (key == IHttpMaxRequestBodySizeFeatureType) + { + feature = _currentIHttpMaxRequestBodySizeFeature; + } + else if (key == IHttpMinRequestBodyDataRateFeatureType) + { + feature = _currentIHttpMinRequestBodyDataRateFeature; + } + else if (key == IHttpMinResponseDataRateFeatureType) + { + feature = _currentIHttpMinResponseDataRateFeature; + } + else if (key == IHttpBodyControlFeatureType) + { + feature = _currentIHttpBodyControlFeature; + } + else if (key == IHttpSendFileFeatureType) + { + feature = _currentIHttpSendFileFeature; + } + else if (MaybeExtra != null) + { + feature = ExtraFeatureGet(key); + } + + return feature ?? ConnectionFeatures[key]; + } + + set + { + _featureRevision++; + + if (key == IHttpRequestFeatureType) + { + _currentIHttpRequestFeature = value; + } + else if (key == IHttpResponseFeatureType) + { + _currentIHttpResponseFeature = value; + } + else if (key == IHttpRequestIdentifierFeatureType) + { + _currentIHttpRequestIdentifierFeature = value; + } + else if (key == IServiceProvidersFeatureType) + { + _currentIServiceProvidersFeature = value; + } + else if (key == IHttpRequestLifetimeFeatureType) + { + _currentIHttpRequestLifetimeFeature = value; + } + else if (key == IHttpConnectionFeatureType) + { + _currentIHttpConnectionFeature = value; + } + else if (key == IHttpAuthenticationFeatureType) + { + _currentIHttpAuthenticationFeature = value; + } + else if (key == IQueryFeatureType) + { + _currentIQueryFeature = value; + } + else if (key == IFormFeatureType) + { + _currentIFormFeature = value; + } + else if (key == IHttpUpgradeFeatureType) + { + _currentIHttpUpgradeFeature = value; + } + else if (key == IHttp2StreamIdFeatureType) + { + _currentIHttp2StreamIdFeature = value; + } + else if (key == IResponseCookiesFeatureType) + { + _currentIResponseCookiesFeature = value; + } + else if (key == IItemsFeatureType) + { + _currentIItemsFeature = value; + } + else if (key == ITlsConnectionFeatureType) + { + _currentITlsConnectionFeature = value; + } + else if (key == IHttpWebSocketFeatureType) + { + _currentIHttpWebSocketFeature = value; + } + else if (key == ISessionFeatureType) + { + _currentISessionFeature = value; + } + else if (key == IHttpMaxRequestBodySizeFeatureType) + { + _currentIHttpMaxRequestBodySizeFeature = value; + } + else if (key == IHttpMinRequestBodyDataRateFeatureType) + { + _currentIHttpMinRequestBodyDataRateFeature = value; + } + else if (key == IHttpMinResponseDataRateFeatureType) + { + _currentIHttpMinResponseDataRateFeature = value; + } + else if (key == IHttpBodyControlFeatureType) + { + _currentIHttpBodyControlFeature = value; + } + else if (key == IHttpSendFileFeatureType) + { + _currentIHttpSendFileFeature = value; + } + else + { + ExtraFeatureSet(key, value); + } + } + } + + void IFeatureCollection.Set(TFeature feature) + { + _featureRevision++; + if (typeof(TFeature) == typeof(IHttpRequestFeature)) + { + _currentIHttpRequestFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttpResponseFeature)) + { + _currentIHttpResponseFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttpRequestIdentifierFeature)) + { + _currentIHttpRequestIdentifierFeature = feature; + } + else if (typeof(TFeature) == typeof(IServiceProvidersFeature)) + { + _currentIServiceProvidersFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttpRequestLifetimeFeature)) + { + _currentIHttpRequestLifetimeFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttpConnectionFeature)) + { + _currentIHttpConnectionFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttpAuthenticationFeature)) + { + _currentIHttpAuthenticationFeature = feature; + } + else if (typeof(TFeature) == typeof(IQueryFeature)) + { + _currentIQueryFeature = feature; + } + else if (typeof(TFeature) == typeof(IFormFeature)) + { + _currentIFormFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttpUpgradeFeature)) + { + _currentIHttpUpgradeFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttp2StreamIdFeature)) + { + _currentIHttp2StreamIdFeature = feature; + } + else if (typeof(TFeature) == typeof(IResponseCookiesFeature)) + { + _currentIResponseCookiesFeature = feature; + } + else if (typeof(TFeature) == typeof(IItemsFeature)) + { + _currentIItemsFeature = feature; + } + else if (typeof(TFeature) == typeof(ITlsConnectionFeature)) + { + _currentITlsConnectionFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttpWebSocketFeature)) + { + _currentIHttpWebSocketFeature = feature; + } + else if (typeof(TFeature) == typeof(ISessionFeature)) + { + _currentISessionFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttpMaxRequestBodySizeFeature)) + { + _currentIHttpMaxRequestBodySizeFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttpMinRequestBodyDataRateFeature)) + { + _currentIHttpMinRequestBodyDataRateFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttpMinResponseDataRateFeature)) + { + _currentIHttpMinResponseDataRateFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttpBodyControlFeature)) + { + _currentIHttpBodyControlFeature = feature; + } + else if (typeof(TFeature) == typeof(IHttpSendFileFeature)) + { + _currentIHttpSendFileFeature = feature; + } + else + { + ExtraFeatureSet(typeof(TFeature), feature); + } + } + + TFeature IFeatureCollection.Get() + { + TFeature feature = default; + if (typeof(TFeature) == typeof(IHttpRequestFeature)) + { + feature = (TFeature)_currentIHttpRequestFeature; + } + else if (typeof(TFeature) == typeof(IHttpResponseFeature)) + { + feature = (TFeature)_currentIHttpResponseFeature; + } + else if (typeof(TFeature) == typeof(IHttpRequestIdentifierFeature)) + { + feature = (TFeature)_currentIHttpRequestIdentifierFeature; + } + else if (typeof(TFeature) == typeof(IServiceProvidersFeature)) + { + feature = (TFeature)_currentIServiceProvidersFeature; + } + else if (typeof(TFeature) == typeof(IHttpRequestLifetimeFeature)) + { + feature = (TFeature)_currentIHttpRequestLifetimeFeature; + } + else if (typeof(TFeature) == typeof(IHttpConnectionFeature)) + { + feature = (TFeature)_currentIHttpConnectionFeature; + } + else if (typeof(TFeature) == typeof(IHttpAuthenticationFeature)) + { + feature = (TFeature)_currentIHttpAuthenticationFeature; + } + else if (typeof(TFeature) == typeof(IQueryFeature)) + { + feature = (TFeature)_currentIQueryFeature; + } + else if (typeof(TFeature) == typeof(IFormFeature)) + { + feature = (TFeature)_currentIFormFeature; + } + else if (typeof(TFeature) == typeof(IHttpUpgradeFeature)) + { + feature = (TFeature)_currentIHttpUpgradeFeature; + } + else if (typeof(TFeature) == typeof(IHttp2StreamIdFeature)) + { + feature = (TFeature)_currentIHttp2StreamIdFeature; + } + else if (typeof(TFeature) == typeof(IResponseCookiesFeature)) + { + feature = (TFeature)_currentIResponseCookiesFeature; + } + else if (typeof(TFeature) == typeof(IItemsFeature)) + { + feature = (TFeature)_currentIItemsFeature; + } + else if (typeof(TFeature) == typeof(ITlsConnectionFeature)) + { + feature = (TFeature)_currentITlsConnectionFeature; + } + else if (typeof(TFeature) == typeof(IHttpWebSocketFeature)) + { + feature = (TFeature)_currentIHttpWebSocketFeature; + } + else if (typeof(TFeature) == typeof(ISessionFeature)) + { + feature = (TFeature)_currentISessionFeature; + } + else if (typeof(TFeature) == typeof(IHttpMaxRequestBodySizeFeature)) + { + feature = (TFeature)_currentIHttpMaxRequestBodySizeFeature; + } + else if (typeof(TFeature) == typeof(IHttpMinRequestBodyDataRateFeature)) + { + feature = (TFeature)_currentIHttpMinRequestBodyDataRateFeature; + } + else if (typeof(TFeature) == typeof(IHttpMinResponseDataRateFeature)) + { + feature = (TFeature)_currentIHttpMinResponseDataRateFeature; + } + else if (typeof(TFeature) == typeof(IHttpBodyControlFeature)) + { + feature = (TFeature)_currentIHttpBodyControlFeature; + } + else if (typeof(TFeature) == typeof(IHttpSendFileFeature)) + { + feature = (TFeature)_currentIHttpSendFileFeature; + } + else if (MaybeExtra != null) + { + feature = (TFeature)(ExtraFeatureGet(typeof(TFeature))); + } + + if (feature == null) + { + feature = ConnectionFeatures.Get(); + } + + return feature; + } + + private IEnumerable> FastEnumerable() + { + if (_currentIHttpRequestFeature != null) + { + yield return new KeyValuePair(IHttpRequestFeatureType, _currentIHttpRequestFeature as IHttpRequestFeature); + } + if (_currentIHttpResponseFeature != null) + { + yield return new KeyValuePair(IHttpResponseFeatureType, _currentIHttpResponseFeature as IHttpResponseFeature); + } + if (_currentIHttpRequestIdentifierFeature != null) + { + yield return new KeyValuePair(IHttpRequestIdentifierFeatureType, _currentIHttpRequestIdentifierFeature as IHttpRequestIdentifierFeature); + } + if (_currentIServiceProvidersFeature != null) + { + yield return new KeyValuePair(IServiceProvidersFeatureType, _currentIServiceProvidersFeature as IServiceProvidersFeature); + } + if (_currentIHttpRequestLifetimeFeature != null) + { + yield return new KeyValuePair(IHttpRequestLifetimeFeatureType, _currentIHttpRequestLifetimeFeature as IHttpRequestLifetimeFeature); + } + if (_currentIHttpConnectionFeature != null) + { + yield return new KeyValuePair(IHttpConnectionFeatureType, _currentIHttpConnectionFeature as IHttpConnectionFeature); + } + if (_currentIHttpAuthenticationFeature != null) + { + yield return new KeyValuePair(IHttpAuthenticationFeatureType, _currentIHttpAuthenticationFeature as IHttpAuthenticationFeature); + } + if (_currentIQueryFeature != null) + { + yield return new KeyValuePair(IQueryFeatureType, _currentIQueryFeature as IQueryFeature); + } + if (_currentIFormFeature != null) + { + yield return new KeyValuePair(IFormFeatureType, _currentIFormFeature as IFormFeature); + } + if (_currentIHttpUpgradeFeature != null) + { + yield return new KeyValuePair(IHttpUpgradeFeatureType, _currentIHttpUpgradeFeature as IHttpUpgradeFeature); + } + if (_currentIHttp2StreamIdFeature != null) + { + yield return new KeyValuePair(IHttp2StreamIdFeatureType, _currentIHttp2StreamIdFeature as IHttp2StreamIdFeature); + } + if (_currentIResponseCookiesFeature != null) + { + yield return new KeyValuePair(IResponseCookiesFeatureType, _currentIResponseCookiesFeature as IResponseCookiesFeature); + } + if (_currentIItemsFeature != null) + { + yield return new KeyValuePair(IItemsFeatureType, _currentIItemsFeature as IItemsFeature); + } + if (_currentITlsConnectionFeature != null) + { + yield return new KeyValuePair(ITlsConnectionFeatureType, _currentITlsConnectionFeature as ITlsConnectionFeature); + } + if (_currentIHttpWebSocketFeature != null) + { + yield return new KeyValuePair(IHttpWebSocketFeatureType, _currentIHttpWebSocketFeature as IHttpWebSocketFeature); + } + if (_currentISessionFeature != null) + { + yield return new KeyValuePair(ISessionFeatureType, _currentISessionFeature as ISessionFeature); + } + if (_currentIHttpMaxRequestBodySizeFeature != null) + { + yield return new KeyValuePair(IHttpMaxRequestBodySizeFeatureType, _currentIHttpMaxRequestBodySizeFeature as IHttpMaxRequestBodySizeFeature); + } + if (_currentIHttpMinRequestBodyDataRateFeature != null) + { + yield return new KeyValuePair(IHttpMinRequestBodyDataRateFeatureType, _currentIHttpMinRequestBodyDataRateFeature as IHttpMinRequestBodyDataRateFeature); + } + if (_currentIHttpMinResponseDataRateFeature != null) + { + yield return new KeyValuePair(IHttpMinResponseDataRateFeatureType, _currentIHttpMinResponseDataRateFeature as IHttpMinResponseDataRateFeature); + } + if (_currentIHttpBodyControlFeature != null) + { + yield return new KeyValuePair(IHttpBodyControlFeatureType, _currentIHttpBodyControlFeature as IHttpBodyControlFeature); + } + if (_currentIHttpSendFileFeature != null) + { + yield return new KeyValuePair(IHttpSendFileFeatureType, _currentIHttpSendFileFeature as IHttpSendFileFeature); + } + + if (MaybeExtra != null) + { + foreach(var item in MaybeExtra) + { + yield return item; + } + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs new file mode 100644 index 0000000000..b0ae93147d --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs @@ -0,0 +1,1353 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.IO.Pipelines; +using System.Linq; +using System.Net; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Primitives; + +// ReSharper disable AccessToModifiedClosure + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public abstract partial class HttpProtocol : IHttpResponseControl + { + private static readonly byte[] _bytesConnectionClose = Encoding.ASCII.GetBytes("\r\nConnection: close"); + private static readonly byte[] _bytesConnectionKeepAlive = Encoding.ASCII.GetBytes("\r\nConnection: keep-alive"); + private static readonly byte[] _bytesTransferEncodingChunked = Encoding.ASCII.GetBytes("\r\nTransfer-Encoding: chunked"); + private static readonly byte[] _bytesServer = Encoding.ASCII.GetBytes("\r\nServer: " + Constants.ServerName); + private static readonly Func, long> _writeChunk = WriteChunk; + + private readonly object _onStartingSync = new Object(); + private readonly object _onCompletedSync = new Object(); + + protected Streams _streams; + + private Stack, object>> _onStarting; + private Stack, object>> _onCompleted; + + private int _requestAborted; + private volatile int _ioCompleted; + private CancellationTokenSource _abortedCts; + private CancellationToken? _manuallySetRequestAbortToken; + + protected RequestProcessingStatus _requestProcessingStatus; + protected volatile bool _keepAlive; // volatile, see: https://msdn.microsoft.com/en-us/library/x13ttww7.aspx + protected bool _upgradeAvailable; + private bool _canHaveBody; + private bool _autoChunk; + private Exception _applicationException; + private BadHttpRequestException _requestRejectedException; + + protected HttpVersion _httpVersion; + + private string _requestId; + private int _requestHeadersParsed; + + private long _responseBytesWritten; + + private readonly IHttpProtocolContext _context; + + protected string _methodText = null; + private string _scheme = null; + + public HttpProtocol(IHttpProtocolContext context) + { + _context = context; + + ServerOptions = ServiceContext.ServerOptions; + HttpResponseControl = this; + RequestBodyPipe = CreateRequestBodyPipe(); + } + + public IHttpResponseControl HttpResponseControl { get; set; } + + public Pipe RequestBodyPipe { get; } + + public ServiceContext ServiceContext => _context.ServiceContext; + private IPEndPoint LocalEndPoint => _context.LocalEndPoint; + private IPEndPoint RemoteEndPoint => _context.RemoteEndPoint; + + public IFeatureCollection ConnectionFeatures => _context.ConnectionFeatures; + public IHttpOutputProducer Output { get; protected set; } + + protected IKestrelTrace Log => ServiceContext.Log; + private DateHeaderValueManager DateHeaderValueManager => ServiceContext.DateHeaderValueManager; + // Hold direct reference to ServerOptions since this is used very often in the request processing path + protected KestrelServerOptions ServerOptions { get; } + protected string ConnectionId => _context.ConnectionId; + + public string ConnectionIdFeature { get; set; } + public bool HasStartedConsumingRequestBody { get; set; } + public long? MaxRequestBodySize { get; set; } + public bool AllowSynchronousIO { get; set; } + + /// + /// The request id. + /// + public string TraceIdentifier + { + set => _requestId = value; + get + { + // don't generate an ID until it is requested + if (_requestId == null) + { + _requestId = CreateRequestId(); + } + return _requestId; + } + } + + public abstract bool IsUpgradableRequest { get; } + public bool IsUpgraded { get; set; } + public IPAddress RemoteIpAddress { get; set; } + public int RemotePort { get; set; } + public IPAddress LocalIpAddress { get; set; } + public int LocalPort { get; set; } + public string Scheme { get; set; } + public HttpMethod Method { get; set; } + public string PathBase { get; set; } + public string Path { get; set; } + public string QueryString { get; set; } + public string RawTarget { get; set; } + + public string HttpVersion + { + get + { + if (_httpVersion == Http.HttpVersion.Http11) + { + return HttpUtilities.Http11Version; + } + if (_httpVersion == Http.HttpVersion.Http10) + { + return HttpUtilities.Http10Version; + } + if (_httpVersion == Http.HttpVersion.Http2) + { + return HttpUtilities.Http2Version; + } + + return string.Empty; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + set + { + // GetKnownVersion returns versions which ReferenceEquals interned string + // As most common path, check for this only in fast-path and inline + if (ReferenceEquals(value, HttpUtilities.Http11Version)) + { + _httpVersion = Http.HttpVersion.Http11; + } + else if (ReferenceEquals(value, HttpUtilities.Http10Version)) + { + _httpVersion = Http.HttpVersion.Http10; + } + else if (ReferenceEquals(value, HttpUtilities.Http2Version)) + { + _httpVersion = Http.HttpVersion.Http2; + } + else + { + HttpVersionSetSlow(value); + } + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void HttpVersionSetSlow(string value) + { + if (value == HttpUtilities.Http11Version) + { + _httpVersion = Http.HttpVersion.Http11; + } + else if (value == HttpUtilities.Http10Version) + { + _httpVersion = Http.HttpVersion.Http10; + } + else if (value == HttpUtilities.Http2Version) + { + _httpVersion = Http.HttpVersion.Http2; + } + else + { + _httpVersion = Http.HttpVersion.Unknown; + } + } + + public IHeaderDictionary RequestHeaders { get; set; } + public Stream RequestBody { get; set; } + + private int _statusCode; + public int StatusCode + { + get => _statusCode; + set + { + if (HasResponseStarted) + { + ThrowResponseAlreadyStartedException(nameof(StatusCode)); + } + + _statusCode = value; + } + } + + private string _reasonPhrase; + + public string ReasonPhrase + { + get => _reasonPhrase; + + set + { + if (HasResponseStarted) + { + ThrowResponseAlreadyStartedException(nameof(ReasonPhrase)); + } + + _reasonPhrase = value; + } + } + + public IHeaderDictionary ResponseHeaders { get; set; } + public Stream ResponseBody { get; set; } + + public CancellationToken RequestAborted + { + get + { + // If a request abort token was previously explicitly set, return it. + if (_manuallySetRequestAbortToken.HasValue) + { + return _manuallySetRequestAbortToken.Value; + } + // Otherwise, get the abort CTS. If we have one, which would mean that someone previously + // asked for the RequestAborted token, simply return its token. If we don't, + // check to see whether we've already aborted, in which case just return an + // already canceled token. Finally, force a source into existence if we still + // don't have one, and return its token. + var cts = _abortedCts; + return + cts != null ? cts.Token : + (_ioCompleted == 1) ? new CancellationToken(true) : + RequestAbortedSource.Token; + } + set + { + // Set an abort token, overriding one we create internally. This setter and associated + // field exist purely to support IHttpRequestLifetimeFeature.set_RequestAborted. + _manuallySetRequestAbortToken = value; + } + } + + private CancellationTokenSource RequestAbortedSource + { + get + { + // Get the abort token, lazily-initializing it if necessary. + // Make sure it's canceled if an abort request already came in. + + // EnsureInitialized can return null since _abortedCts is reset to null + // after it's already been initialized to a non-null value. + // If EnsureInitialized does return null, this property was accessed between + // requests so it's safe to return an ephemeral CancellationTokenSource. + var cts = LazyInitializer.EnsureInitialized(ref _abortedCts, () => new CancellationTokenSource()) + ?? new CancellationTokenSource(); + + if (_ioCompleted == 1) + { + cts.Cancel(); + } + return cts; + } + } + + public bool HasResponseStarted => _requestProcessingStatus == RequestProcessingStatus.ResponseStarted; + + protected HttpRequestHeaders HttpRequestHeaders { get; } = new HttpRequestHeaders(); + + protected HttpResponseHeaders HttpResponseHeaders { get; } = new HttpResponseHeaders(); + + public MinDataRate MinRequestBodyDataRate { get; set; } + + public MinDataRate MinResponseDataRate { get; set; } + + public void InitializeStreams(MessageBody messageBody) + { + if (_streams == null) + { + _streams = new Streams(bodyControl: this, httpResponseControl: this); + } + + (RequestBody, ResponseBody) = _streams.Start(messageBody); + } + + public void StopStreams() => _streams.Stop(); + + // For testing + internal void ResetState() + { + _requestProcessingStatus = RequestProcessingStatus.RequestPending; + } + + public void Reset() + { + _onStarting = null; + _onCompleted = null; + + _requestProcessingStatus = RequestProcessingStatus.RequestPending; + _autoChunk = false; + _applicationException = null; + _requestRejectedException = null; + + ResetFeatureCollection(); + + HasStartedConsumingRequestBody = false; + MaxRequestBodySize = ServerOptions.Limits.MaxRequestBodySize; + AllowSynchronousIO = ServerOptions.AllowSynchronousIO; + TraceIdentifier = null; + Method = HttpMethod.None; + _methodText = null; + PathBase = null; + Path = null; + RawTarget = null; + QueryString = null; + _httpVersion = Http.HttpVersion.Unknown; + _statusCode = StatusCodes.Status200OK; + _reasonPhrase = null; + + var remoteEndPoint = RemoteEndPoint; + RemoteIpAddress = remoteEndPoint?.Address; + RemotePort = remoteEndPoint?.Port ?? 0; + + var localEndPoint = LocalEndPoint; + LocalIpAddress = localEndPoint?.Address; + LocalPort = localEndPoint?.Port ?? 0; + + ConnectionIdFeature = ConnectionId; + + HttpRequestHeaders.Reset(); + HttpResponseHeaders.Reset(); + RequestHeaders = HttpRequestHeaders; + ResponseHeaders = HttpResponseHeaders; + + if (_scheme == null) + { + var tlsFeature = ConnectionFeatures?[typeof(ITlsConnectionFeature)]; + _scheme = tlsFeature != null ? "https" : "http"; + } + + Scheme = _scheme; + + _manuallySetRequestAbortToken = null; + _abortedCts = null; + + // Allow two bytes for \r\n after headers + _requestHeadersParsed = 0; + + _responseBytesWritten = 0; + + MinRequestBodyDataRate = ServerOptions.Limits.MinRequestBodyDataRate; + MinResponseDataRate = ServerOptions.Limits.MinResponseDataRate; + + OnReset(); + } + + protected abstract void OnReset(); + + protected virtual void OnRequestProcessingEnding() + { + } + + protected virtual void OnRequestProcessingEnded() + { + } + + protected virtual void BeginRequestProcessing() + { + } + + protected virtual bool BeginRead(out ValueTask awaitable) + { + awaitable = default; + return false; + } + + protected abstract string CreateRequestId(); + + protected abstract MessageBody CreateMessageBody(); + + protected abstract bool TryParseRequest(ReadResult result, out bool endConnection); + + private void CancelRequestAbortedToken() + { + try + { + RequestAbortedSource.Cancel(); + _abortedCts = null; + } + catch (Exception ex) + { + Log.ApplicationError(ConnectionId, TraceIdentifier, ex); + } + } + + public void OnInputOrOutputCompleted() + { + if (Interlocked.Exchange(ref _ioCompleted, 1) != 0) + { + return; + } + + _keepAlive = false; + + Output.Dispose(); + + // Potentially calling user code. CancelRequestAbortedToken logs any exceptions. + ServiceContext.Scheduler.Schedule(state => ((HttpProtocol)state).CancelRequestAbortedToken(), this); + } + + /// + /// Immediately kill the connection and poison the request and response streams with an error if there is one. + /// + public void Abort(ConnectionAbortedException abortReason) + { + if (Interlocked.Exchange(ref _requestAborted, 1) != 0) + { + return; + } + + _streams?.Abort(abortReason); + + // Abort output prior to calling OnIOCompleted() to give the transport the chance to + // complete the input with the correct error and message. + Output.Abort(abortReason); + + OnInputOrOutputCompleted(); + } + + public void OnHeader(Span name, Span value) + { + _requestHeadersParsed++; + if (_requestHeadersParsed > ServerOptions.Limits.MaxRequestHeaderCount) + { + BadHttpRequestException.Throw(RequestRejectionReason.TooManyHeaders); + } + var valueString = value.GetAsciiStringNonNullCharacters(); + + HttpRequestHeaders.Append(name, valueString); + } + + public async Task ProcessRequestsAsync(IHttpApplication application) + { + try + { + await ProcessRequests(application); + } + catch (BadHttpRequestException ex) + { + // Handle BadHttpRequestException thrown during request line or header parsing. + // SetBadRequestState logs the error. + SetBadRequestState(ex); + } + catch (ConnectionResetException ex) + { + // Don't log ECONNRESET errors made between requests. Browsers like IE will reset connections regularly. + if (_requestProcessingStatus != RequestProcessingStatus.RequestPending) + { + Log.RequestProcessingError(ConnectionId, ex); + } + } + catch (IOException ex) + { + Log.RequestProcessingError(ConnectionId, ex); + } + catch (Exception ex) + { + Log.LogWarning(0, ex, CoreStrings.RequestProcessingEndError); + } + finally + { + try + { + OnRequestProcessingEnding(); + await TryProduceInvalidRequestResponse(); + + // Prevent RequestAborted from firing. + Reset(); + + Output.Dispose(); + } + catch (Exception ex) + { + Log.LogWarning(0, ex, CoreStrings.ConnectionShutdownError); + } + finally + { + OnRequestProcessingEnded(); + } + } + } + + private async Task ProcessRequests(IHttpApplication application) + { + // Keep-alive is default for HTTP/1.1 and HTTP/2; parsing and errors will change its value + _keepAlive = true; + + while (_keepAlive) + { + BeginRequestProcessing(); + + var result = default(ReadResult); + var endConnection = false; + do + { + if (BeginRead(out var awaitable)) + { + result = await awaitable; + } + } while (!TryParseRequest(result, out endConnection)); + + if (endConnection) + { + // Connection finished, stop processing requests + return; + } + + var messageBody = CreateMessageBody(); + if (!messageBody.RequestKeepAlive) + { + _keepAlive = false; + } + + _upgradeAvailable = messageBody.RequestUpgrade; + + InitializeStreams(messageBody); + + var httpContext = application.CreateContext(this); + + try + { + KestrelEventSource.Log.RequestStart(this); + + // Run the application code for this request + await application.ProcessRequestAsync(httpContext); + + if (_ioCompleted == 0) + { + VerifyResponseContentLength(); + } + } + catch (BadHttpRequestException ex) + { + // Capture BadHttpRequestException for further processing + // This has to be caught here so StatusCode is set properly before disposing the HttpContext + // (DisposeContext logs StatusCode). + SetBadRequestState(ex); + ReportApplicationError(ex); + } + catch (Exception ex) + { + ReportApplicationError(ex); + } + + KestrelEventSource.Log.RequestStop(this); + + // Trigger OnStarting if it hasn't been called yet and the app hasn't + // already failed. If an OnStarting callback throws we can go through + // our normal error handling in ProduceEnd. + // https://github.com/aspnet/KestrelHttpServer/issues/43 + if (!HasResponseStarted && _applicationException == null && _onStarting != null) + { + await FireOnStarting(); + } + + // At this point all user code that needs use to the request or response streams has completed. + // Using these streams in the OnCompleted callback is not allowed. + StopStreams(); + + // 4XX responses are written by TryProduceInvalidRequestResponse during connection tear down. + if (_requestRejectedException == null) + { + if (_ioCompleted == 0) + { + // Call ProduceEnd() before consuming the rest of the request body to prevent + // delaying clients waiting for the chunk terminator: + // + // https://github.com/dotnet/corefx/issues/17330#issuecomment-288248663 + // + // This also prevents the 100 Continue response from being sent if the app + // never tried to read the body. + // https://github.com/aspnet/KestrelHttpServer/issues/2102 + // + // ProduceEnd() must be called before _application.DisposeContext(), to ensure + // HttpContext.Response.StatusCode is correctly set when + // IHttpContextFactory.Dispose(HttpContext) is called. + await ProduceEnd(); + } + else if (!HasResponseStarted) + { + // If the request was aborted and no response was sent, there's no + // meaningful status code to log. + StatusCode = 0; + } + } + + if (_onCompleted != null) + { + await FireOnCompleted(); + } + + application.DisposeContext(httpContext, _applicationException); + + // Even for non-keep-alive requests, try to consume the entire body to avoid RSTs. + if (_ioCompleted == 0 && _requestRejectedException == null && !messageBody.IsEmpty) + { + await messageBody.ConsumeAsync(); + } + + if (HasStartedConsumingRequestBody) + { + RequestBodyPipe.Reader.Complete(); + + // Wait for MessageBody.PumpAsync() to call RequestBodyPipe.Writer.Complete(). + await messageBody.StopAsync(); + + // At this point both the request body pipe reader and writer should be completed. + RequestBodyPipe.Reset(); + } + } + } + + public void OnStarting(Func callback, object state) + { + lock (_onStartingSync) + { + if (HasResponseStarted) + { + ThrowResponseAlreadyStartedException(nameof(OnStarting)); + } + + if (_onStarting == null) + { + _onStarting = new Stack, object>>(); + } + _onStarting.Push(new KeyValuePair, object>(callback, state)); + } + } + + public void OnCompleted(Func callback, object state) + { + lock (_onCompletedSync) + { + if (_onCompleted == null) + { + _onCompleted = new Stack, object>>(); + } + _onCompleted.Push(new KeyValuePair, object>(callback, state)); + } + } + + protected Task FireOnStarting() + { + Stack, object>> onStarting; + lock (_onStartingSync) + { + onStarting = _onStarting; + _onStarting = null; + } + + if (onStarting == null) + { + return Task.CompletedTask; + } + else + { + return FireOnStartingMayAwait(onStarting); + } + + } + + private Task FireOnStartingMayAwait(Stack, object>> onStarting) + { + try + { + var count = onStarting.Count; + for (var i = 0; i < count; i++) + { + var entry = onStarting.Pop(); + var task = entry.Key.Invoke(entry.Value); + if (!ReferenceEquals(task, Task.CompletedTask)) + { + return FireOnStartingAwaited(task, onStarting); + } + } + } + catch (Exception ex) + { + ReportApplicationError(ex); + } + + return Task.CompletedTask; + } + + private async Task FireOnStartingAwaited(Task currentTask, Stack, object>> onStarting) + { + try + { + await currentTask; + + var count = onStarting.Count; + for (var i = 0; i < count; i++) + { + var entry = onStarting.Pop(); + await entry.Key.Invoke(entry.Value); + } + } + catch (Exception ex) + { + ReportApplicationError(ex); + } + } + + protected Task FireOnCompleted() + { + Stack, object>> onCompleted; + lock (_onCompletedSync) + { + onCompleted = _onCompleted; + _onCompleted = null; + } + + if (onCompleted == null) + { + return Task.CompletedTask; + } + + return FireOnCompletedAwaited(onCompleted); + } + + private async Task FireOnCompletedAwaited(Stack, object>> onCompleted) + { + foreach (var entry in onCompleted) + { + try + { + await entry.Key.Invoke(entry.Value); + } + catch (Exception ex) + { + Log.ApplicationError(ConnectionId, TraceIdentifier, ex); + } + } + } + + public Task FlushAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + if (!HasResponseStarted) + { + var initializeTask = InitializeResponseAsync(0); + // If return is Task.CompletedTask no awaiting is required + if (!ReferenceEquals(initializeTask, Task.CompletedTask)) + { + return FlushAsyncAwaited(initializeTask, cancellationToken); + } + } + + return Output.FlushAsync(cancellationToken); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private async Task FlushAsyncAwaited(Task initializeTask, CancellationToken cancellationToken) + { + await initializeTask; + await Output.FlushAsync(cancellationToken); + } + + public Task WriteAsync(ReadOnlyMemory data, CancellationToken cancellationToken = default(CancellationToken)) + { + // For the first write, ensure headers are flushed if WriteDataAsync isn't called. + var firstWrite = !HasResponseStarted; + + if (firstWrite) + { + var initializeTask = InitializeResponseAsync(data.Length); + // If return is Task.CompletedTask no awaiting is required + if (!ReferenceEquals(initializeTask, Task.CompletedTask)) + { + return WriteAsyncAwaited(initializeTask, data, cancellationToken); + } + } + else + { + VerifyAndUpdateWrite(data.Length); + } + + if (_canHaveBody) + { + if (_autoChunk) + { + if (data.Length == 0) + { + return !firstWrite ? Task.CompletedTask : FlushAsync(cancellationToken); + } + return WriteChunkedAsync(data, cancellationToken); + } + else + { + CheckLastWrite(); + return Output.WriteDataAsync(data.Span, cancellationToken: cancellationToken); + } + } + else + { + HandleNonBodyResponseWrite(); + return !firstWrite ? Task.CompletedTask : FlushAsync(cancellationToken); + } + } + + public async Task WriteAsyncAwaited(Task initializeTask, ReadOnlyMemory data, CancellationToken cancellationToken) + { + await initializeTask; + + // WriteAsyncAwaited is only called for the first write to the body. + // Ensure headers are flushed if Write(Chunked)Async isn't called. + if (_canHaveBody) + { + if (_autoChunk) + { + if (data.Length == 0) + { + await FlushAsync(cancellationToken); + return; + } + + await WriteChunkedAsync(data, cancellationToken); + } + else + { + CheckLastWrite(); + await Output.WriteDataAsync(data.Span, cancellationToken: cancellationToken); + } + } + else + { + HandleNonBodyResponseWrite(); + await FlushAsync(cancellationToken); + } + } + + private void VerifyAndUpdateWrite(int count) + { + var responseHeaders = HttpResponseHeaders; + + if (responseHeaders != null && + !responseHeaders.HasTransferEncoding && + responseHeaders.ContentLength.HasValue && + _responseBytesWritten + count > responseHeaders.ContentLength.Value) + { + _keepAlive = false; + ThrowTooManyBytesWritten(count); + } + + _responseBytesWritten += count; + } + + [StackTraceHidden] + private void ThrowTooManyBytesWritten(int count) + { + throw GetTooManyBytesWrittenException(count); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private InvalidOperationException GetTooManyBytesWrittenException(int count) + { + var responseHeaders = HttpResponseHeaders; + return new InvalidOperationException( + CoreStrings.FormatTooManyBytesWritten(_responseBytesWritten + count, responseHeaders.ContentLength.Value)); + } + + private void CheckLastWrite() + { + var responseHeaders = HttpResponseHeaders; + + // Prevent firing request aborted token if this is the last write, to avoid + // aborting the request if the app is still running when the client receives + // the final bytes of the response and gracefully closes the connection. + // + // Called after VerifyAndUpdateWrite(), so _responseBytesWritten has already been updated. + if (responseHeaders != null && + !responseHeaders.HasTransferEncoding && + responseHeaders.ContentLength.HasValue && + _responseBytesWritten == responseHeaders.ContentLength.Value) + { + _abortedCts = null; + } + } + + protected void VerifyResponseContentLength() + { + var responseHeaders = HttpResponseHeaders; + + if (Method != HttpMethod.Head && + StatusCode != StatusCodes.Status304NotModified && + !responseHeaders.HasTransferEncoding && + responseHeaders.ContentLength.HasValue && + _responseBytesWritten < responseHeaders.ContentLength.Value) + { + // We need to close the connection if any bytes were written since the client + // cannot be certain of how many bytes it will receive. + if (_responseBytesWritten > 0) + { + _keepAlive = false; + } + + ReportApplicationError(new InvalidOperationException( + CoreStrings.FormatTooFewBytesWritten(_responseBytesWritten, responseHeaders.ContentLength.Value))); + } + } + + private Task WriteChunkedAsync(ReadOnlyMemory data, CancellationToken cancellationToken) + { + return Output.WriteAsync(_writeChunk, data); + } + + private static long WriteChunk(PipeWriter writableBuffer, ReadOnlyMemory buffer) + { + var bytesWritten = 0L; + if (buffer.Length > 0) + { + var writer = new BufferWriter(writableBuffer); + + ChunkWriter.WriteBeginChunkBytes(ref writer, buffer.Length); + writer.Write(buffer.Span); + ChunkWriter.WriteEndChunkBytes(ref writer); + writer.Commit(); + + bytesWritten = writer.BytesCommitted; + } + + return bytesWritten; + } + + private static ArraySegment CreateAsciiByteArraySegment(string text) + { + var bytes = Encoding.ASCII.GetBytes(text); + return new ArraySegment(bytes); + } + + public void ProduceContinue() + { + if (HasResponseStarted) + { + return; + } + + if (_httpVersion != Http.HttpVersion.Http10 && + RequestHeaders.TryGetValue("Expect", out var expect) && + (expect.FirstOrDefault() ?? "").Equals("100-continue", StringComparison.OrdinalIgnoreCase)) + { + Output.Write100ContinueAsync(default(CancellationToken)).GetAwaiter().GetResult(); + } + } + + public Task InitializeResponseAsync(int firstWriteByteCount) + { + var startingTask = FireOnStarting(); + // If return is Task.CompletedTask no awaiting is required + if (!ReferenceEquals(startingTask, Task.CompletedTask)) + { + return InitializeResponseAwaited(startingTask, firstWriteByteCount); + } + + if (_applicationException != null) + { + ThrowResponseAbortedException(); + } + + VerifyAndUpdateWrite(firstWriteByteCount); + ProduceStart(appCompleted: false); + + return Task.CompletedTask; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public async Task InitializeResponseAwaited(Task startingTask, int firstWriteByteCount) + { + await startingTask; + + if (_applicationException != null) + { + ThrowResponseAbortedException(); + } + + VerifyAndUpdateWrite(firstWriteByteCount); + ProduceStart(appCompleted: false); + } + + private void ProduceStart(bool appCompleted) + { + if (HasResponseStarted) + { + return; + } + + _requestProcessingStatus = RequestProcessingStatus.ResponseStarted; + + CreateResponseHeader(appCompleted); + } + + protected Task TryProduceInvalidRequestResponse() + { + // If _ioCompleted is set, the connection has already been closed. + if (_requestRejectedException != null && _ioCompleted == 0) + { + return ProduceEnd(); + } + + return Task.CompletedTask; + } + + protected Task ProduceEnd() + { + if (_requestRejectedException != null || _applicationException != null) + { + if (HasResponseStarted) + { + // We can no longer change the response, so we simply close the connection. + _keepAlive = false; + return Task.CompletedTask; + } + + // If the request was rejected, the error state has already been set by SetBadRequestState and + // that should take precedence. + if (_requestRejectedException != null) + { + SetErrorResponseException(_requestRejectedException); + } + else + { + // 500 Internal Server Error + SetErrorResponseHeaders(statusCode: StatusCodes.Status500InternalServerError); + } + } + + if (!HasResponseStarted) + { + return ProduceEndAwaited(); + } + + return WriteSuffix(); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private async Task ProduceEndAwaited() + { + ProduceStart(appCompleted: true); + + // Force flush + await Output.FlushAsync(default(CancellationToken)); + + await WriteSuffix(); + } + + private Task WriteSuffix() + { + // _autoChunk should be checked after we are sure ProduceStart() has been called + // since ProduceStart() may set _autoChunk to true. + if (_autoChunk || _httpVersion == Http.HttpVersion.Http2) + { + return WriteSuffixAwaited(); + } + + if (_keepAlive) + { + Log.ConnectionKeepAlive(ConnectionId); + } + + if (Method == HttpMethod.Head && _responseBytesWritten > 0) + { + Log.ConnectionHeadResponseBodyWrite(ConnectionId, _responseBytesWritten); + } + + return Task.CompletedTask; + } + + private async Task WriteSuffixAwaited() + { + // For the same reason we call CheckLastWrite() in Content-Length responses. + _abortedCts = null; + + await Output.WriteStreamSuffixAsync(default(CancellationToken)); + + if (_keepAlive) + { + Log.ConnectionKeepAlive(ConnectionId); + } + + if (Method == HttpMethod.Head && _responseBytesWritten > 0) + { + Log.ConnectionHeadResponseBodyWrite(ConnectionId, _responseBytesWritten); + } + } + + private void CreateResponseHeader(bool appCompleted) + { + var responseHeaders = HttpResponseHeaders; + var hasConnection = responseHeaders.HasConnection; + var connectionOptions = HttpHeaders.ParseConnection(responseHeaders.HeaderConnection); + var hasTransferEncoding = responseHeaders.HasTransferEncoding; + + if (_keepAlive && hasConnection && (connectionOptions & ConnectionOptions.KeepAlive) != ConnectionOptions.KeepAlive) + { + _keepAlive = false; + } + + // https://tools.ietf.org/html/rfc7230#section-3.3.1 + // If any transfer coding other than + // chunked is applied to a response payload body, the sender MUST either + // apply chunked as the final transfer coding or terminate the message + // by closing the connection. + if (hasTransferEncoding && + HttpHeaders.GetFinalTransferCoding(responseHeaders.HeaderTransferEncoding) != TransferCoding.Chunked) + { + _keepAlive = false; + } + + // Set whether response can have body + _canHaveBody = StatusCanHaveBody(StatusCode) && Method != HttpMethod.Head; + + // Don't set the Content-Length or Transfer-Encoding headers + // automatically for HEAD requests or 204, 205, 304 responses. + if (_canHaveBody) + { + if (!hasTransferEncoding && !responseHeaders.ContentLength.HasValue) + { + if (appCompleted && StatusCode != StatusCodes.Status101SwitchingProtocols) + { + // Since the app has completed and we are only now generating + // the headers we can safely set the Content-Length to 0. + responseHeaders.ContentLength = 0; + } + else + { + // Note for future reference: never change this to set _autoChunk to true on HTTP/1.0 + // connections, even if we were to infer the client supports it because an HTTP/1.0 request + // was received that used chunked encoding. Sending a chunked response to an HTTP/1.0 + // client would break compliance with RFC 7230 (section 3.3.1): + // + // A server MUST NOT send a response containing Transfer-Encoding unless the corresponding + // request indicates HTTP/1.1 (or later). + // + // This also covers HTTP/2, which forbids chunked encoding in RFC 7540 (section 8.1: + // + // The chunked transfer encoding defined in Section 4.1 of [RFC7230] MUST NOT be used in HTTP/2. + if (_httpVersion == Http.HttpVersion.Http11 && StatusCode != StatusCodes.Status101SwitchingProtocols) + { + _autoChunk = true; + responseHeaders.SetRawTransferEncoding("chunked", _bytesTransferEncodingChunked); + } + else + { + _keepAlive = false; + } + } + } + } + else if (hasTransferEncoding) + { + RejectNonBodyTransferEncodingResponse(appCompleted); + } + + responseHeaders.SetReadOnly(); + + if (!hasConnection && _httpVersion != Http.HttpVersion.Http2) + { + if (!_keepAlive) + { + responseHeaders.SetRawConnection("close", _bytesConnectionClose); + } + else if (_httpVersion == Http.HttpVersion.Http10) + { + responseHeaders.SetRawConnection("keep-alive", _bytesConnectionKeepAlive); + } + } + + if (ServerOptions.AddServerHeader && !responseHeaders.HasServer) + { + responseHeaders.SetRawServer(Constants.ServerName, _bytesServer); + } + + if (!responseHeaders.HasDate) + { + var dateHeaderValues = DateHeaderValueManager.GetDateHeaderValues(); + responseHeaders.SetRawDate(dateHeaderValues.String, dateHeaderValues.Bytes); + } + + Output.WriteResponseHeaders(StatusCode, ReasonPhrase, responseHeaders); + } + + public bool StatusCanHaveBody(int statusCode) + { + // List of status codes taken from Microsoft.Net.Http.Server.Response + return statusCode != StatusCodes.Status204NoContent && + statusCode != StatusCodes.Status205ResetContent && + statusCode != StatusCodes.Status304NotModified; + } + + private void ThrowResponseAlreadyStartedException(string value) + { + throw new InvalidOperationException(CoreStrings.FormatParameterReadOnlyAfterResponseStarted(value)); + } + + private void RejectNonBodyTransferEncodingResponse(bool appCompleted) + { + var ex = new InvalidOperationException(CoreStrings.FormatHeaderNotAllowedOnResponse("Transfer-Encoding", StatusCode)); + if (!appCompleted) + { + // Back out of header creation surface exeception in user code + _requestProcessingStatus = RequestProcessingStatus.AppStarted; + throw ex; + } + else + { + ReportApplicationError(ex); + + // 500 Internal Server Error + SetErrorResponseHeaders(statusCode: StatusCodes.Status500InternalServerError); + } + } + + private void SetErrorResponseException(BadHttpRequestException ex) + { + SetErrorResponseHeaders(ex.StatusCode); + + if (!StringValues.IsNullOrEmpty(ex.AllowedHeader)) + { + HttpResponseHeaders.HeaderAllow = ex.AllowedHeader; + } + } + + private void SetErrorResponseHeaders(int statusCode) + { + Debug.Assert(!HasResponseStarted, $"{nameof(SetErrorResponseHeaders)} called after response had already started."); + + StatusCode = statusCode; + ReasonPhrase = null; + + var responseHeaders = HttpResponseHeaders; + responseHeaders.Reset(); + var dateHeaderValues = DateHeaderValueManager.GetDateHeaderValues(); + + responseHeaders.SetRawDate(dateHeaderValues.String, dateHeaderValues.Bytes); + + responseHeaders.ContentLength = 0; + + if (ServerOptions.AddServerHeader) + { + responseHeaders.SetRawServer(Constants.ServerName, _bytesServer); + } + } + + public void HandleNonBodyResponseWrite() + { + // Writes to HEAD response are ignored and logged at the end of the request + if (Method != HttpMethod.Head) + { + ThrowWritingToResponseBodyNotSupported(); + } + } + + [StackTraceHidden] + private void ThrowWritingToResponseBodyNotSupported() + { + // Throw Exception for 204, 205, 304 responses. + throw new InvalidOperationException(CoreStrings.FormatWritingToResponseBodyNotSupported(StatusCode)); + } + + [StackTraceHidden] + private void ThrowResponseAbortedException() + { + throw new ObjectDisposedException(CoreStrings.UnhandledApplicationException, _applicationException); + } + + [StackTraceHidden] + public void ThrowRequestTargetRejected(Span target) + => throw GetInvalidRequestTargetException(target); + + [MethodImpl(MethodImplOptions.NoInlining)] + private BadHttpRequestException GetInvalidRequestTargetException(Span target) + => BadHttpRequestException.GetException( + RequestRejectionReason.InvalidRequestTarget, + Log.IsEnabled(LogLevel.Information) + ? target.GetAsciiStringEscaped(Constants.MaxExceptionDetailSize) + : string.Empty); + + public void SetBadRequestState(RequestRejectionReason reason) + { + SetBadRequestState(BadHttpRequestException.GetException(reason)); + } + + public void SetBadRequestState(BadHttpRequestException ex) + { + Log.ConnectionBadRequest(ConnectionId, ex); + + if (!HasResponseStarted) + { + SetErrorResponseException(ex); + } + + _keepAlive = false; + _requestRejectedException = ex; + } + + protected void ReportApplicationError(Exception ex) + { + if (_applicationException == null) + { + _applicationException = ex; + } + else if (_applicationException is AggregateException) + { + _applicationException = new AggregateException(_applicationException, ex).Flatten(); + } + else + { + _applicationException = new AggregateException(_applicationException, ex); + } + + Log.ApplicationError(ConnectionId, TraceIdentifier, ex); + } + + private Pipe CreateRequestBodyPipe() + => new Pipe(new PipeOptions + ( + pool: _context.MemoryPool, + readerScheduler: ServiceContext.Scheduler, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: 1, + resumeWriterThreshold: 1, + useSynchronizationContext: false, + minimumSegmentSize: KestrelMemoryPool.MinimumSegmentSize + )); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestHeaders.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestHeaders.cs new file mode 100644 index 0000000000..9ee5a2a27e --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestHeaders.cs @@ -0,0 +1,104 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.Extensions.Primitives; +using Microsoft.Net.Http.Headers; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public partial class HttpRequestHeaders : HttpHeaders + { + private static long ParseContentLength(string value) + { + long parsed; + if (!HeaderUtilities.TryParseNonNegativeInt64(value, out parsed)) + { + BadHttpRequestException.Throw(RequestRejectionReason.InvalidContentLength, value); + } + + return parsed; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void SetValueUnknown(string key, in StringValues value) + { + Unknown[key] = value; + } + + public unsafe void Append(Span name, string value) + { + fixed (byte* namePtr = &MemoryMarshal.GetReference(name)) + { + Append(namePtr, name.Length, value); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private unsafe void AppendUnknownHeaders(byte* pKeyBytes, int keyLength, string value) + { + string key = new string('\0', keyLength); + fixed (char* keyBuffer = key) + { + if (!StringUtilities.TryGetAsciiString(pKeyBytes, keyBuffer, keyLength)) + { + BadHttpRequestException.Throw(RequestRejectionReason.InvalidCharactersInHeaderName); + } + } + + StringValues existing; + Unknown.TryGetValue(key, out existing); + Unknown[key] = AppendValue(existing, value); + } + + public Enumerator GetEnumerator() + { + return new Enumerator(this); + } + + protected override IEnumerator> GetEnumeratorFast() + { + return GetEnumerator(); + } + + public partial struct Enumerator : IEnumerator> + { + private readonly HttpRequestHeaders _collection; + private readonly long _bits; + private int _state; + private KeyValuePair _current; + private readonly bool _hasUnknown; + private Dictionary.Enumerator _unknownEnumerator; + + internal Enumerator(HttpRequestHeaders collection) + { + _collection = collection; + _bits = collection._bits; + _state = 0; + _current = default(KeyValuePair); + _hasUnknown = collection.MaybeUnknown != null; + _unknownEnumerator = _hasUnknown + ? collection.MaybeUnknown.GetEnumerator() + : default(Dictionary.Enumerator); + } + + public KeyValuePair Current => _current; + + object IEnumerator.Current => _current; + + public void Dispose() + { + } + + public void Reset() + { + _state = 0; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestStream.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestStream.cs new file mode 100644 index 0000000000..51b7e7b6cd --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestStream.cs @@ -0,0 +1,217 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Runtime.ExceptionServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + internal class HttpRequestStream : ReadOnlyStream + { + private readonly IHttpBodyControlFeature _bodyControl; + private MessageBody _body; + private HttpStreamState _state; + private Exception _error; + + public HttpRequestStream(IHttpBodyControlFeature bodyControl) + { + _bodyControl = bodyControl; + _state = HttpStreamState.Closed; + } + + public override bool CanSeek => false; + + public override long Length + => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Flush() + { + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (!_bodyControl.AllowSynchronousIO) + { + throw new InvalidOperationException(CoreStrings.SynchronousReadsDisallowed); + } + + return ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + var task = ReadAsync(buffer, offset, count, default(CancellationToken), state); + if (callback != null) + { + task.ContinueWith(t => callback.Invoke(t)); + } + return task; + } + + public override int EndRead(IAsyncResult asyncResult) + { + return ((Task)asyncResult).GetAwaiter().GetResult(); + } + + private Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state) + { + var tcs = new TaskCompletionSource(state); + var task = ReadAsync(buffer, offset, count, cancellationToken); + task.ContinueWith((task2, state2) => + { + var tcs2 = (TaskCompletionSource)state2; + if (task2.IsCanceled) + { + tcs2.SetCanceled(); + } + else if (task2.IsFaulted) + { + tcs2.SetException(task2.Exception); + } + else + { + tcs2.SetResult(task2.Result); + } + }, tcs, cancellationToken); + return tcs.Task; + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateState(cancellationToken); + + return ReadAsyncInternal(new Memory(buffer, offset, count), cancellationToken).AsTask(); + } + +#if NETCOREAPP2_1 + public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + { + ValidateState(cancellationToken); + + return ReadAsyncInternal(destination, cancellationToken); + } +#endif + + private async ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken) + { + try + { + return await _body.ReadAsync(buffer, cancellationToken); + } + catch (ConnectionAbortedException ex) + { + throw new TaskCanceledException("The request was aborted", ex); + } + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + if (destination == null) + { + throw new ArgumentNullException(nameof(destination)); + } + if (bufferSize <= 0) + { + throw new ArgumentException(CoreStrings.PositiveNumberRequired, nameof(bufferSize)); + } + + ValidateState(cancellationToken); + + return CopyToAsyncInternal(destination, cancellationToken); + } + + private async Task CopyToAsyncInternal(Stream destination, CancellationToken cancellationToken) + { + try + { + await _body.CopyToAsync(destination, cancellationToken); + } + catch (ConnectionAbortedException ex) + { + throw new TaskCanceledException("The request was aborted", ex); + } + } + + public void StartAcceptingReads(MessageBody body) + { + // Only start if not aborted + if (_state == HttpStreamState.Closed) + { + _state = HttpStreamState.Open; + _body = body; + } + } + + public void StopAcceptingReads() + { + // Can't use dispose (or close) as can be disposed too early by user code + // As exampled in EngineTests.ZeroContentLengthNotSetAutomaticallyForCertainStatusCodes + _state = HttpStreamState.Closed; + _body = null; + } + + public void Abort(Exception error = null) + { + // We don't want to throw an ODE until the app func actually completes. + // If the request is aborted, we throw a TaskCanceledException instead, + // unless error is not null, in which case we throw it. + if (_state != HttpStreamState.Closed) + { + _state = HttpStreamState.Aborted; + _error = error; + } + } + + private void ValidateState(CancellationToken cancellationToken) + { + switch (_state) + { + case HttpStreamState.Open: + if (cancellationToken.IsCancellationRequested) + { + cancellationToken.ThrowIfCancellationRequested(); + } + break; + case HttpStreamState.Closed: + throw new ObjectDisposedException(nameof(HttpRequestStream)); + case HttpStreamState.Aborted: + if (_error != null) + { + ExceptionDispatchInfo.Capture(_error).Throw(); + } + else + { + throw new TaskCanceledException(); + } + break; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestTargetForm.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestTargetForm.cs new file mode 100644 index 0000000000..0e43670fa3 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestTargetForm.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public enum HttpRequestTarget + { + Unknown = -1, + // origin-form is the most common + OriginForm, + AbsoluteForm, + AuthorityForm, + AsteriskForm + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponseHeaders.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponseHeaders.cs new file mode 100644 index 0000000000..1df80f3dc6 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponseHeaders.cs @@ -0,0 +1,110 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Collections; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using Microsoft.Extensions.Primitives; +using Microsoft.Net.Http.Headers; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public partial class HttpResponseHeaders : HttpHeaders + { + private static readonly byte[] _CrLf = new[] { (byte)'\r', (byte)'\n' }; + private static readonly byte[] _colonSpace = new[] { (byte)':', (byte)' ' }; + + public Enumerator GetEnumerator() + { + return new Enumerator(this); + } + + protected override IEnumerator> GetEnumeratorFast() + { + return GetEnumerator(); + } + + internal void CopyTo(ref BufferWriter buffer) + { + CopyToFast(ref buffer); + if (MaybeUnknown != null) + { + foreach (var kv in MaybeUnknown) + { + foreach (var value in kv.Value) + { + if (value != null) + { + buffer.Write(_CrLf); + PipelineExtensions.WriteAsciiNoValidation(ref buffer, kv.Key); + buffer.Write(_colonSpace); + PipelineExtensions.WriteAsciiNoValidation(ref buffer, value); + } + } + } + } + } + + private static long ParseContentLength(string value) + { + long parsed; + if (!HeaderUtilities.TryParseNonNegativeInt64(value, out parsed)) + { + ThrowInvalidContentLengthException(value); + } + + return parsed; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void SetValueUnknown(string key, in StringValues value) + { + ValidateHeaderCharacters(key); + Unknown[key] = value; + } + + private static void ThrowInvalidContentLengthException(string value) + { + throw new InvalidOperationException(CoreStrings.FormatInvalidContentLength_InvalidNumber(value)); + } + + public partial struct Enumerator : IEnumerator> + { + private readonly HttpResponseHeaders _collection; + private readonly long _bits; + private int _state; + private KeyValuePair _current; + private readonly bool _hasUnknown; + private Dictionary.Enumerator _unknownEnumerator; + + internal Enumerator(HttpResponseHeaders collection) + { + _collection = collection; + _bits = collection._bits; + _state = 0; + _current = default; + _hasUnknown = collection.MaybeUnknown != null; + _unknownEnumerator = _hasUnknown + ? collection.MaybeUnknown.GetEnumerator() + : default; + } + + public KeyValuePair Current => _current; + + object IEnumerator.Current => _current; + + public void Dispose() + { + } + + public void Reset() + { + _state = 0; + } + } + + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponseStream.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponseStream.cs new file mode 100644 index 0000000000..aefbe6fb19 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpResponseStream.cs @@ -0,0 +1,170 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + internal class HttpResponseStream : WriteOnlyStream + { + private readonly IHttpBodyControlFeature _bodyControl; + private readonly IHttpResponseControl _httpResponseControl; + private HttpStreamState _state; + + public HttpResponseStream(IHttpBodyControlFeature bodyControl, IHttpResponseControl httpResponseControl) + { + _bodyControl = bodyControl; + _httpResponseControl = httpResponseControl; + _state = HttpStreamState.Closed; + } + + public override bool CanSeek => false; + + public override long Length + => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Flush() + { + FlushAsync(default(CancellationToken)).GetAwaiter().GetResult(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + ValidateState(cancellationToken); + + return _httpResponseControl.FlushAsync(cancellationToken); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + if (!_bodyControl.AllowSynchronousIO) + { + throw new InvalidOperationException(CoreStrings.SynchronousWritesDisallowed); + } + + WriteAsync(buffer, offset, count, default(CancellationToken)).GetAwaiter().GetResult(); + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + var task = WriteAsync(buffer, offset, count, default(CancellationToken), state); + if (callback != null) + { + task.ContinueWith(t => callback.Invoke(t)); + } + return task; + } + + public override void EndWrite(IAsyncResult asyncResult) + { + ((Task)asyncResult).GetAwaiter().GetResult(); + } + + private Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state) + { + var tcs = new TaskCompletionSource(state); + var task = WriteAsync(buffer, offset, count, cancellationToken); + task.ContinueWith((task2, state2) => + { + var tcs2 = (TaskCompletionSource)state2; + if (task2.IsCanceled) + { + tcs2.SetCanceled(); + } + else if (task2.IsFaulted) + { + tcs2.SetException(task2.Exception); + } + else + { + tcs2.SetResult(null); + } + }, tcs, cancellationToken); + return tcs.Task; + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateState(cancellationToken); + + return _httpResponseControl.WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken); + } + +#if NETCOREAPP2_1 + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + ValidateState(cancellationToken); + + return new ValueTask(_httpResponseControl.WriteAsync(source, cancellationToken)); + } +#endif + + public void StartAcceptingWrites() + { + // Only start if not aborted + if (_state == HttpStreamState.Closed) + { + _state = HttpStreamState.Open; + } + } + + public void StopAcceptingWrites() + { + // Can't use dispose (or close) as can be disposed too early by user code + // As exampled in EngineTests.ZeroContentLengthNotSetAutomaticallyForCertainStatusCodes + _state = HttpStreamState.Closed; + } + + public void Abort() + { + // We don't want to throw an ODE until the app func actually completes. + if (_state != HttpStreamState.Closed) + { + _state = HttpStreamState.Aborted; + } + } + + private void ValidateState(CancellationToken cancellationToken) + { + switch (_state) + { + case HttpStreamState.Open: + if (cancellationToken.IsCancellationRequested) + { + cancellationToken.ThrowIfCancellationRequested(); + } + break; + case HttpStreamState.Closed: + throw new ObjectDisposedException(nameof(HttpResponseStream), CoreStrings.WritingToResponseBodyAfterResponseCompleted); + case HttpStreamState.Aborted: + if (cancellationToken.IsCancellationRequested) + { + // Aborted state only throws on write if cancellationToken requests it + cancellationToken.ThrowIfCancellationRequested(); + } + break; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpScheme.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpScheme.cs new file mode 100644 index 0000000000..dfd4642f3d --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpScheme.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public enum HttpScheme + { + Unknown = -1, + Http = 0, + Https = 1 + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpStreamState.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpStreamState.cs new file mode 100644 index 0000000000..34d5e904f5 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpStreamState.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + enum HttpStreamState + { + Open, + Closed, + Aborted + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpUpgradeStream.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpUpgradeStream.cs new file mode 100644 index 0000000000..d6fedc0518 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpUpgradeStream.cs @@ -0,0 +1,202 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + internal class HttpUpgradeStream : Stream + { + private readonly Stream _requestStream; + private readonly Stream _responseStream; + + public HttpUpgradeStream(Stream requestStream, Stream responseStream) + { + _requestStream = requestStream; + _responseStream = responseStream; + } + + public override bool CanRead + { + get + { + return _requestStream.CanRead; + } + } + + public override bool CanSeek + { + get + { + return _requestStream.CanSeek; + } + } + + public override bool CanTimeout + { + get + { + return _responseStream.CanTimeout || _requestStream.CanTimeout; + } + } + + public override bool CanWrite + { + get + { + return _responseStream.CanWrite; + } + } + + public override long Length + { + get + { + return _requestStream.Length; + } + } + + public override long Position + { + get + { + return _requestStream.Position; + } + set + { + _requestStream.Position = value; + } + } + + public override int ReadTimeout + { + get + { + return _requestStream.ReadTimeout; + } + set + { + _requestStream.ReadTimeout = value; + } + } + + public override int WriteTimeout + { + get + { + return _responseStream.WriteTimeout; + } + set + { + _responseStream.WriteTimeout = value; + } + } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + _requestStream.Dispose(); + _responseStream.Dispose(); + } + } + + public override void Flush() + { + _responseStream.Flush(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _responseStream.FlushAsync(cancellationToken); + } + + public override void Close() + { + _requestStream.Close(); + _responseStream.Close(); + } + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _requestStream.BeginRead(buffer, offset, count, callback, state); + } + + public override int EndRead(IAsyncResult asyncResult) + { + return _requestStream.EndRead(asyncResult); + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _responseStream.BeginWrite(buffer, offset, count, callback, state); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + _responseStream.EndWrite(asyncResult); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _requestStream.ReadAsync(buffer, offset, count, cancellationToken); + } + +#if NETCOREAPP2_1 + public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + { + return _requestStream.ReadAsync(destination, cancellationToken); + } +#endif + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + return _requestStream.CopyToAsync(destination, bufferSize, cancellationToken); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _responseStream.WriteAsync(buffer, offset, count, cancellationToken); + } + +#if NETCOREAPP2_1 + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + return _responseStream.WriteAsync(source, cancellationToken); + } +#endif + + public override long Seek(long offset, SeekOrigin origin) + { + return _requestStream.Seek(offset, origin); + } + + public override void SetLength(long value) + { + _requestStream.SetLength(value); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return _requestStream.Read(buffer, offset, count); + } + + public override int ReadByte() + { + return _requestStream.ReadByte(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + _responseStream.Write(buffer, offset, count); + } + + public override void WriteByte(byte value) + { + _responseStream.WriteByte(value); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpVersion.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpVersion.cs new file mode 100644 index 0000000000..832a1c5616 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpVersion.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public enum HttpVersion + { + Unknown = -1, + Http10 = 0, + Http11 = 1, + Http2 + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/IHttpHeadersHandler.cs b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpHeadersHandler.cs new file mode 100644 index 0000000000..9a322f0da9 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpHeadersHandler.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public interface IHttpHeadersHandler + { + void OnHeader(Span name, Span value); + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/IHttpOutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpOutputProducer.cs new file mode 100644 index 0000000000..41dfdbbbec --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpOutputProducer.cs @@ -0,0 +1,25 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public interface IHttpOutputProducer : IDisposable + { + void Abort(ConnectionAbortedException abortReason); + Task WriteAsync(Func callback, T state); + Task FlushAsync(CancellationToken cancellationToken); + Task Write100ContinueAsync(CancellationToken cancellationToken); + void WriteResponseHeaders(int statusCode, string ReasonPhrase, HttpResponseHeaders responseHeaders); + // The reason this is ReadOnlySpan and not ReadOnlyMemory is because writes are always + // synchronous. Flushing to get back pressure is the only time we truly go async but + // that's after the buffer is copied + Task WriteDataAsync(ReadOnlySpan data, CancellationToken cancellationToken); + Task WriteStreamSuffixAsync(CancellationToken cancellationToken); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/IHttpParser.cs b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpParser.cs new file mode 100644 index 0000000000..efd8e9445b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpParser.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public interface IHttpParser where TRequestHandler : IHttpHeadersHandler, IHttpRequestLineHandler + { + bool ParseRequestLine(TRequestHandler handler, in ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined); + + bool ParseHeaders(TRequestHandler handler, in ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined, out int consumedBytes); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/IHttpProtocolContext.cs b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpProtocolContext.cs new file mode 100644 index 0000000000..2aa526b867 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpProtocolContext.cs @@ -0,0 +1,19 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.Net; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public interface IHttpProtocolContext + { + string ConnectionId { get; set; } + ServiceContext ServiceContext { get; set; } + IFeatureCollection ConnectionFeatures { get; set; } + MemoryPool MemoryPool { get; set; } + IPEndPoint RemoteEndPoint { get; set; } + IPEndPoint LocalEndPoint { get; set; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/IHttpRequestLineHandler.cs b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpRequestLineHandler.cs new file mode 100644 index 0000000000..ac91138512 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpRequestLineHandler.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public interface IHttpRequestLineHandler + { + void OnStartLine(HttpMethod method, HttpVersion version, Span target, Span path, Span query, Span customMethod, bool pathEncoded); + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/IHttpResponseControl.cs b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpResponseControl.cs new file mode 100644 index 0000000000..9a42aa6116 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpResponseControl.cs @@ -0,0 +1,16 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public interface IHttpResponseControl + { + void ProduceContinue(); + Task WriteAsync(ReadOnlyMemory data, CancellationToken cancellationToken); + Task FlushAsync(CancellationToken cancellationToken); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs new file mode 100644 index 0000000000..33bd8ebfb5 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs @@ -0,0 +1,172 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public abstract class MessageBody + { + private static readonly MessageBody _zeroContentLengthClose = new ForZeroContentLength(keepAlive: false); + private static readonly MessageBody _zeroContentLengthKeepAlive = new ForZeroContentLength(keepAlive: true); + + private readonly HttpProtocol _context; + + private bool _send100Continue = true; + + protected MessageBody(HttpProtocol context) + { + _context = context; + } + + public static MessageBody ZeroContentLengthClose => _zeroContentLengthClose; + + public static MessageBody ZeroContentLengthKeepAlive => _zeroContentLengthKeepAlive; + + public bool RequestKeepAlive { get; protected set; } + + public bool RequestUpgrade { get; protected set; } + + public virtual bool IsEmpty => false; + + protected IKestrelTrace Log => _context.ServiceContext.Log; + + public virtual async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default(CancellationToken)) + { + TryInit(); + + while (true) + { + var result = await _context.RequestBodyPipe.Reader.ReadAsync(); + var readableBuffer = result.Buffer; + var consumed = readableBuffer.End; + + try + { + if (!readableBuffer.IsEmpty) + { + // buffer.Count is int + var actual = (int) Math.Min(readableBuffer.Length, buffer.Length); + var slice = readableBuffer.Slice(0, actual); + consumed = readableBuffer.GetPosition(actual); + slice.CopyTo(buffer.Span); + return actual; + } + + if (result.IsCompleted) + { + return 0; + } + } + finally + { + _context.RequestBodyPipe.Reader.AdvanceTo(consumed); + } + } + } + + public virtual async Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default(CancellationToken)) + { + TryInit(); + + while (true) + { + var result = await _context.RequestBodyPipe.Reader.ReadAsync(); + var readableBuffer = result.Buffer; + var consumed = readableBuffer.End; + + try + { + if (!readableBuffer.IsEmpty) + { + foreach (var memory in readableBuffer) + { + // REVIEW: This *could* be slower if 2 things are true + // - The WriteAsync(ReadOnlyMemory) isn't overridden on the destination + // - We change the Kestrel Memory Pool to not use pinned arrays but instead use native memory +#if NETCOREAPP2_1 + await destination.WriteAsync(memory); +#else + var array = memory.GetArray(); + await destination.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken); +#endif + } + } + + if (result.IsCompleted) + { + return; + } + } + finally + { + _context.RequestBodyPipe.Reader.AdvanceTo(consumed); + } + } + } + + public virtual Task ConsumeAsync() + { + TryInit(); + + return OnConsumeAsync(); + } + + protected abstract Task OnConsumeAsync(); + + public abstract Task StopAsync(); + + protected void TryProduceContinue() + { + if (_send100Continue) + { + _context.HttpResponseControl.ProduceContinue(); + _send100Continue = false; + } + } + + private void TryInit() + { + if (!_context.HasStartedConsumingRequestBody) + { + OnReadStarting(); + _context.HasStartedConsumingRequestBody = true; + OnReadStarted(); + } + } + + protected virtual void OnReadStarting() + { + } + + protected virtual void OnReadStarted() + { + } + + private class ForZeroContentLength : MessageBody + { + public ForZeroContentLength(bool keepAlive) + : base(null) + { + RequestKeepAlive = keepAlive; + } + + public override bool IsEmpty => true; + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default(CancellationToken)) => new ValueTask(0); + + public override Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default(CancellationToken)) => Task.CompletedTask; + + public override Task ConsumeAsync() => Task.CompletedTask; + + public override Task StopAsync() => Task.CompletedTask; + + protected override Task OnConsumeAsync() => Task.CompletedTask; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/PathNormalizer.cs b/src/Servers/Kestrel/Core/src/Internal/Http/PathNormalizer.cs new file mode 100644 index 0000000000..68cdddf7ce --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/PathNormalizer.cs @@ -0,0 +1,208 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public static class PathNormalizer + { + private const byte ByteSlash = (byte)'/'; + private const byte ByteDot = (byte)'.'; + + // In-place implementation of the algorithm from https://tools.ietf.org/html/rfc3986#section-5.2.4 + public static unsafe int RemoveDotSegments(Span input) + { + fixed (byte* start = &MemoryMarshal.GetReference(input)) + { + var end = start + input.Length; + return RemoveDotSegments(start, end); + } + } + + public static unsafe int RemoveDotSegments(byte* start, byte* end) + { + if (!ContainsDotSegments(start, end)) + { + return (int)(end - start); + } + + var src = start; + var dst = start; + + while (src < end) + { + var ch1 = *src; + Debug.Assert(ch1 == '/', "Path segment must always start with a '/'"); + + byte ch2, ch3, ch4; + + switch (end - src) + { + case 1: + break; + case 2: + ch2 = *(src + 1); + + if (ch2 == ByteDot) + { + // B. if the input buffer begins with a prefix of "/./" or "/.", + // where "." is a complete path segment, then replace that + // prefix with "/" in the input buffer; otherwise, + src += 1; + *src = ByteSlash; + continue; + } + + break; + case 3: + ch2 = *(src + 1); + ch3 = *(src + 2); + + if (ch2 == ByteDot && ch3 == ByteDot) + { + // C. if the input buffer begins with a prefix of "/../" or "/..", + // where ".." is a complete path segment, then replace that + // prefix with "/" in the input buffer and remove the last + // segment and its preceding "/" (if any) from the output + // buffer; otherwise, + src += 2; + *src = ByteSlash; + + if (dst > start) + { + do + { + dst--; + } while (dst > start && *dst != ByteSlash); + } + + continue; + } + else if (ch2 == ByteDot && ch3 == ByteSlash) + { + // B. if the input buffer begins with a prefix of "/./" or "/.", + // where "." is a complete path segment, then replace that + // prefix with "/" in the input buffer; otherwise, + src += 2; + continue; + } + + break; + default: + ch2 = *(src + 1); + ch3 = *(src + 2); + ch4 = *(src + 3); + + if (ch2 == ByteDot && ch3 == ByteDot && ch4 == ByteSlash) + { + // C. if the input buffer begins with a prefix of "/../" or "/..", + // where ".." is a complete path segment, then replace that + // prefix with "/" in the input buffer and remove the last + // segment and its preceding "/" (if any) from the output + // buffer; otherwise, + src += 3; + + if (dst > start) + { + do + { + dst--; + } while (dst > start && *dst != ByteSlash); + } + + continue; + } + else if (ch2 == ByteDot && ch3 == ByteSlash) + { + // B. if the input buffer begins with a prefix of "/./" or "/.", + // where "." is a complete path segment, then replace that + // prefix with "/" in the input buffer; otherwise, + src += 2; + continue; + } + + break; + } + + // E. move the first path segment in the input buffer to the end of + // the output buffer, including the initial "/" character (if + // any) and any subsequent characters up to, but not including, + // the next "/" character or the end of the input buffer. + do + { + *dst++ = ch1; + ch1 = *++src; + } while (src < end && ch1 != ByteSlash); + } + + if (dst == start) + { + *dst++ = ByteSlash; + } + + return (int)(dst - start); + } + + public static unsafe bool ContainsDotSegments(byte* start, byte* end) + { + var src = start; + var dst = start; + + while (src < end) + { + var ch1 = *src; + Debug.Assert(ch1 == '/', "Path segment must always start with a '/'"); + + byte ch2, ch3, ch4; + + switch (end - src) + { + case 1: + break; + case 2: + ch2 = *(src + 1); + + if (ch2 == ByteDot) + { + return true; + } + + break; + case 3: + ch2 = *(src + 1); + ch3 = *(src + 2); + + if ((ch2 == ByteDot && ch3 == ByteDot) || + (ch2 == ByteDot && ch3 == ByteSlash)) + { + return true; + } + + break; + default: + ch2 = *(src + 1); + ch3 = *(src + 2); + ch4 = *(src + 3); + + if ((ch2 == ByteDot && ch3 == ByteDot && ch4 == ByteSlash) || + (ch2 == ByteDot && ch3 == ByteSlash)) + { + return true; + } + + break; + } + + do + { + ch1 = *++src; + } while (src < end && ch1 != ByteSlash); + } + + return false; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/PipelineExtensions.cs b/src/Servers/Kestrel/Core/src/Internal/Http/PipelineExtensions.cs new file mode 100644 index 0000000000..e56c43b23c --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/PipelineExtensions.cs @@ -0,0 +1,273 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public static class PipelineExtensions + { + private const int _maxULongByteLength = 20; + + [ThreadStatic] + private static byte[] _numericBytesScratch; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ReadOnlySpan ToSpan(this ReadOnlySequence buffer) + { + if (buffer.IsSingleSegment) + { + return buffer.First.Span; + } + return buffer.ToArray(); + } + + public static ArraySegment GetArray(this Memory buffer) + { + return ((ReadOnlyMemory)buffer).GetArray(); + } + + public static ArraySegment GetArray(this ReadOnlyMemory memory) + { + if (!MemoryMarshal.TryGetArray(memory, out var result)) + { + throw new InvalidOperationException("Buffer backed by array was expected"); + } + return result; + } + + internal static unsafe void WriteAsciiNoValidation(ref this BufferWriter buffer, string data) + { + if (string.IsNullOrEmpty(data)) + { + return; + } + + var dest = buffer.Span; + var destLength = dest.Length; + var sourceLength = data.Length; + + // Fast path, try copying to the available memory directly + if (sourceLength <= destLength) + { + fixed (char* input = data) + fixed (byte* output = &MemoryMarshal.GetReference(dest)) + { + EncodeAsciiCharsToBytes(input, output, sourceLength); + } + + buffer.Advance(sourceLength); + } + else + { + WriteAsciiMultiWrite(ref buffer, data); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static unsafe void WriteNumeric(ref this BufferWriter buffer, ulong number) + { + const byte AsciiDigitStart = (byte)'0'; + + var span = buffer.Span; + var bytesLeftInBlock = span.Length; + + // Fast path, try copying to the available memory directly + var simpleWrite = true; + fixed (byte* output = &MemoryMarshal.GetReference(span)) + { + var start = output; + if (number < 10 && bytesLeftInBlock >= 1) + { + *(start) = (byte)(((uint)number) + AsciiDigitStart); + buffer.Advance(1); + } + else if (number < 100 && bytesLeftInBlock >= 2) + { + var val = (uint)number; + var tens = (byte)((val * 205u) >> 11); // div10, valid to 1028 + + *(start) = (byte)(tens + AsciiDigitStart); + *(start + 1) = (byte)(val - (tens * 10) + AsciiDigitStart); + buffer.Advance(2); + } + else if (number < 1000 && bytesLeftInBlock >= 3) + { + var val = (uint)number; + var digit0 = (byte)((val * 41u) >> 12); // div100, valid to 1098 + var digits01 = (byte)((val * 205u) >> 11); // div10, valid to 1028 + + *(start) = (byte)(digit0 + AsciiDigitStart); + *(start + 1) = (byte)(digits01 - (digit0 * 10) + AsciiDigitStart); + *(start + 2) = (byte)(val - (digits01 * 10) + AsciiDigitStart); + buffer.Advance(3); + } + else + { + simpleWrite = false; + } + } + + if (!simpleWrite) + { + WriteNumericMultiWrite(ref buffer, number); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void WriteNumericMultiWrite(ref this BufferWriter buffer, ulong number) + { + const byte AsciiDigitStart = (byte)'0'; + + var value = number; + var position = _maxULongByteLength; + var byteBuffer = NumericBytesScratch; + do + { + // Consider using Math.DivRem() if available + var quotient = value / 10; + byteBuffer[--position] = (byte)(AsciiDigitStart + (value - quotient * 10)); // 0x30 = '0' + value = quotient; + } + while (value != 0); + + var length = _maxULongByteLength - position; + buffer.Write(new ReadOnlySpan(byteBuffer, position, length)); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private unsafe static void WriteAsciiMultiWrite(ref this BufferWriter buffer, string data) + { + var remaining = data.Length; + + fixed (char* input = data) + { + var inputSlice = input; + + while (remaining > 0) + { + var writable = Math.Min(remaining, buffer.Span.Length); + + if (writable == 0) + { + buffer.Ensure(); + continue; + } + + fixed (byte* output = &MemoryMarshal.GetReference(buffer.Span)) + { + EncodeAsciiCharsToBytes(inputSlice, output, writable); + } + + inputSlice += writable; + remaining -= writable; + + buffer.Advance(writable); + } + } + } + + private static unsafe void EncodeAsciiCharsToBytes(char* input, byte* output, int length) + { + // Note: Not BIGENDIAN or check for non-ascii + const int Shift16Shift24 = (1 << 16) | (1 << 24); + const int Shift8Identity = (1 << 8) | (1); + + // Encode as bytes upto the first non-ASCII byte and return count encoded + int i = 0; + // Use Intrinsic switch + if (IntPtr.Size == 8) // 64 bit + { + if (length < 4) goto trailing; + + int unaligned = (int)(((ulong)input) & 0x7) >> 1; + // Unaligned chars + for (; i < unaligned; i++) + { + char ch = *(input + i); + *(output + i) = (byte)ch; // Cast convert + } + + // Aligned + int ulongDoubleCount = (length - i) & ~0x7; + for (; i < ulongDoubleCount; i += 8) + { + ulong inputUlong0 = *(ulong*)(input + i); + ulong inputUlong1 = *(ulong*)(input + i + 4); + // Pack 16 ASCII chars into 16 bytes + *(uint*)(output + i) = + ((uint)((inputUlong0 * Shift16Shift24) >> 24) & 0xffff) | + ((uint)((inputUlong0 * Shift8Identity) >> 24) & 0xffff0000); + *(uint*)(output + i + 4) = + ((uint)((inputUlong1 * Shift16Shift24) >> 24) & 0xffff) | + ((uint)((inputUlong1 * Shift8Identity) >> 24) & 0xffff0000); + } + if (length - 4 > i) + { + ulong inputUlong = *(ulong*)(input + i); + // Pack 8 ASCII chars into 8 bytes + *(uint*)(output + i) = + ((uint)((inputUlong * Shift16Shift24) >> 24) & 0xffff) | + ((uint)((inputUlong * Shift8Identity) >> 24) & 0xffff0000); + i += 4; + } + + trailing: + for (; i < length; i++) + { + char ch = *(input + i); + *(output + i) = (byte)ch; // Cast convert + } + } + else // 32 bit + { + // Unaligned chars + if ((unchecked((int)input) & 0x2) != 0) + { + char ch = *input; + i = 1; + *(output) = (byte)ch; // Cast convert + } + + // Aligned + int uintCount = (length - i) & ~0x3; + for (; i < uintCount; i += 4) + { + uint inputUint0 = *(uint*)(input + i); + uint inputUint1 = *(uint*)(input + i + 2); + // Pack 4 ASCII chars into 4 bytes + *(ushort*)(output + i) = (ushort)(inputUint0 | (inputUint0 >> 8)); + *(ushort*)(output + i + 2) = (ushort)(inputUint1 | (inputUint1 >> 8)); + } + if (length - 1 > i) + { + uint inputUint = *(uint*)(input + i); + // Pack 2 ASCII chars into 2 bytes + *(ushort*)(output + i) = (ushort)(inputUint | (inputUint >> 8)); + i += 2; + } + + if (i < length) + { + char ch = *(input + i); + *(output + i) = (byte)ch; // Cast convert + i = length; + } + } + } + + private static byte[] NumericBytesScratch => _numericBytesScratch ?? CreateNumericBytesScratch(); + + [MethodImpl(MethodImplOptions.NoInlining)] + private static byte[] CreateNumericBytesScratch() + { + var bytes = new byte[_maxULongByteLength]; + _numericBytesScratch = bytes; + return bytes; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/ProduceEndType.cs b/src/Servers/Kestrel/Core/src/Internal/Http/ProduceEndType.cs new file mode 100644 index 0000000000..72107f90e7 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/ProduceEndType.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public enum ProduceEndType + { + SocketShutdown, + SocketDisconnect, + ConnectionKeepAlive, + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/ReasonPhrases.cs b/src/Servers/Kestrel/Core/src/Internal/Http/ReasonPhrases.cs new file mode 100644 index 0000000000..d372113bda --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/ReasonPhrases.cs @@ -0,0 +1,236 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Diagnostics; +using System.Globalization; +using System.Text; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public static class ReasonPhrases + { + private static readonly byte[] _bytesStatus100 = CreateStatusBytes(StatusCodes.Status100Continue); + private static readonly byte[] _bytesStatus101 = CreateStatusBytes(StatusCodes.Status101SwitchingProtocols); + private static readonly byte[] _bytesStatus102 = CreateStatusBytes(StatusCodes.Status102Processing); + + private static readonly byte[] _bytesStatus200 = CreateStatusBytes(StatusCodes.Status200OK); + private static readonly byte[] _bytesStatus201 = CreateStatusBytes(StatusCodes.Status201Created); + private static readonly byte[] _bytesStatus202 = CreateStatusBytes(StatusCodes.Status202Accepted); + private static readonly byte[] _bytesStatus203 = CreateStatusBytes(StatusCodes.Status203NonAuthoritative); + private static readonly byte[] _bytesStatus204 = CreateStatusBytes(StatusCodes.Status204NoContent); + private static readonly byte[] _bytesStatus205 = CreateStatusBytes(StatusCodes.Status205ResetContent); + private static readonly byte[] _bytesStatus206 = CreateStatusBytes(StatusCodes.Status206PartialContent); + private static readonly byte[] _bytesStatus207 = CreateStatusBytes(StatusCodes.Status207MultiStatus); + private static readonly byte[] _bytesStatus208 = CreateStatusBytes(StatusCodes.Status208AlreadyReported); + private static readonly byte[] _bytesStatus226 = CreateStatusBytes(StatusCodes.Status226IMUsed); + + private static readonly byte[] _bytesStatus300 = CreateStatusBytes(StatusCodes.Status300MultipleChoices); + private static readonly byte[] _bytesStatus301 = CreateStatusBytes(StatusCodes.Status301MovedPermanently); + private static readonly byte[] _bytesStatus302 = CreateStatusBytes(StatusCodes.Status302Found); + private static readonly byte[] _bytesStatus303 = CreateStatusBytes(StatusCodes.Status303SeeOther); + private static readonly byte[] _bytesStatus304 = CreateStatusBytes(StatusCodes.Status304NotModified); + private static readonly byte[] _bytesStatus305 = CreateStatusBytes(StatusCodes.Status305UseProxy); + private static readonly byte[] _bytesStatus306 = CreateStatusBytes(StatusCodes.Status306SwitchProxy); + private static readonly byte[] _bytesStatus307 = CreateStatusBytes(StatusCodes.Status307TemporaryRedirect); + private static readonly byte[] _bytesStatus308 = CreateStatusBytes(StatusCodes.Status308PermanentRedirect); + + private static readonly byte[] _bytesStatus400 = CreateStatusBytes(StatusCodes.Status400BadRequest); + private static readonly byte[] _bytesStatus401 = CreateStatusBytes(StatusCodes.Status401Unauthorized); + private static readonly byte[] _bytesStatus402 = CreateStatusBytes(StatusCodes.Status402PaymentRequired); + private static readonly byte[] _bytesStatus403 = CreateStatusBytes(StatusCodes.Status403Forbidden); + private static readonly byte[] _bytesStatus404 = CreateStatusBytes(StatusCodes.Status404NotFound); + private static readonly byte[] _bytesStatus405 = CreateStatusBytes(StatusCodes.Status405MethodNotAllowed); + private static readonly byte[] _bytesStatus406 = CreateStatusBytes(StatusCodes.Status406NotAcceptable); + private static readonly byte[] _bytesStatus407 = CreateStatusBytes(StatusCodes.Status407ProxyAuthenticationRequired); + private static readonly byte[] _bytesStatus408 = CreateStatusBytes(StatusCodes.Status408RequestTimeout); + private static readonly byte[] _bytesStatus409 = CreateStatusBytes(StatusCodes.Status409Conflict); + private static readonly byte[] _bytesStatus410 = CreateStatusBytes(StatusCodes.Status410Gone); + private static readonly byte[] _bytesStatus411 = CreateStatusBytes(StatusCodes.Status411LengthRequired); + private static readonly byte[] _bytesStatus412 = CreateStatusBytes(StatusCodes.Status412PreconditionFailed); + private static readonly byte[] _bytesStatus413 = CreateStatusBytes(StatusCodes.Status413PayloadTooLarge); + private static readonly byte[] _bytesStatus414 = CreateStatusBytes(StatusCodes.Status414UriTooLong); + private static readonly byte[] _bytesStatus415 = CreateStatusBytes(StatusCodes.Status415UnsupportedMediaType); + private static readonly byte[] _bytesStatus416 = CreateStatusBytes(StatusCodes.Status416RangeNotSatisfiable); + private static readonly byte[] _bytesStatus417 = CreateStatusBytes(StatusCodes.Status417ExpectationFailed); + private static readonly byte[] _bytesStatus418 = CreateStatusBytes(StatusCodes.Status418ImATeapot); + private static readonly byte[] _bytesStatus419 = CreateStatusBytes(StatusCodes.Status419AuthenticationTimeout); + private static readonly byte[] _bytesStatus421 = CreateStatusBytes(StatusCodes.Status421MisdirectedRequest); + private static readonly byte[] _bytesStatus422 = CreateStatusBytes(StatusCodes.Status422UnprocessableEntity); + private static readonly byte[] _bytesStatus423 = CreateStatusBytes(StatusCodes.Status423Locked); + private static readonly byte[] _bytesStatus424 = CreateStatusBytes(StatusCodes.Status424FailedDependency); + private static readonly byte[] _bytesStatus426 = CreateStatusBytes(StatusCodes.Status426UpgradeRequired); + private static readonly byte[] _bytesStatus428 = CreateStatusBytes(StatusCodes.Status428PreconditionRequired); + private static readonly byte[] _bytesStatus429 = CreateStatusBytes(StatusCodes.Status429TooManyRequests); + private static readonly byte[] _bytesStatus431 = CreateStatusBytes(StatusCodes.Status431RequestHeaderFieldsTooLarge); + private static readonly byte[] _bytesStatus451 = CreateStatusBytes(StatusCodes.Status451UnavailableForLegalReasons); + + private static readonly byte[] _bytesStatus500 = CreateStatusBytes(StatusCodes.Status500InternalServerError); + private static readonly byte[] _bytesStatus501 = CreateStatusBytes(StatusCodes.Status501NotImplemented); + private static readonly byte[] _bytesStatus502 = CreateStatusBytes(StatusCodes.Status502BadGateway); + private static readonly byte[] _bytesStatus503 = CreateStatusBytes(StatusCodes.Status503ServiceUnavailable); + private static readonly byte[] _bytesStatus504 = CreateStatusBytes(StatusCodes.Status504GatewayTimeout); + private static readonly byte[] _bytesStatus505 = CreateStatusBytes(StatusCodes.Status505HttpVersionNotsupported); + private static readonly byte[] _bytesStatus506 = CreateStatusBytes(StatusCodes.Status506VariantAlsoNegotiates); + private static readonly byte[] _bytesStatus507 = CreateStatusBytes(StatusCodes.Status507InsufficientStorage); + private static readonly byte[] _bytesStatus508 = CreateStatusBytes(StatusCodes.Status508LoopDetected); + private static readonly byte[] _bytesStatus510 = CreateStatusBytes(StatusCodes.Status510NotExtended); + private static readonly byte[] _bytesStatus511 = CreateStatusBytes(StatusCodes.Status511NetworkAuthenticationRequired); + + private static byte[] CreateStatusBytes(int statusCode) + { + var reasonPhrase = WebUtilities.ReasonPhrases.GetReasonPhrase(statusCode); + Debug.Assert(!string.IsNullOrEmpty(reasonPhrase)); + + return Encoding.ASCII.GetBytes(statusCode.ToString(CultureInfo.InvariantCulture) + " " + reasonPhrase); + } + + public static byte[] ToStatusBytes(int statusCode, string reasonPhrase = null) + { + if (string.IsNullOrEmpty(reasonPhrase)) + { + switch (statusCode) + { + case StatusCodes.Status100Continue: + return _bytesStatus100; + case StatusCodes.Status101SwitchingProtocols: + return _bytesStatus101; + case StatusCodes.Status102Processing: + return _bytesStatus102; + + case StatusCodes.Status200OK: + return _bytesStatus200; + case StatusCodes.Status201Created: + return _bytesStatus201; + case StatusCodes.Status202Accepted: + return _bytesStatus202; + case StatusCodes.Status203NonAuthoritative: + return _bytesStatus203; + case StatusCodes.Status204NoContent: + return _bytesStatus204; + case StatusCodes.Status205ResetContent: + return _bytesStatus205; + case StatusCodes.Status206PartialContent: + return _bytesStatus206; + case StatusCodes.Status207MultiStatus: + return _bytesStatus207; + case StatusCodes.Status208AlreadyReported: + return _bytesStatus208; + case StatusCodes.Status226IMUsed: + return _bytesStatus226; + + case StatusCodes.Status300MultipleChoices: + return _bytesStatus300; + case StatusCodes.Status301MovedPermanently: + return _bytesStatus301; + case StatusCodes.Status302Found: + return _bytesStatus302; + case StatusCodes.Status303SeeOther: + return _bytesStatus303; + case StatusCodes.Status304NotModified: + return _bytesStatus304; + case StatusCodes.Status305UseProxy: + return _bytesStatus305; + case StatusCodes.Status306SwitchProxy: + return _bytesStatus306; + case StatusCodes.Status307TemporaryRedirect: + return _bytesStatus307; + case StatusCodes.Status308PermanentRedirect: + return _bytesStatus308; + + case StatusCodes.Status400BadRequest: + return _bytesStatus400; + case StatusCodes.Status401Unauthorized: + return _bytesStatus401; + case StatusCodes.Status402PaymentRequired: + return _bytesStatus402; + case StatusCodes.Status403Forbidden: + return _bytesStatus403; + case StatusCodes.Status404NotFound: + return _bytesStatus404; + case StatusCodes.Status405MethodNotAllowed: + return _bytesStatus405; + case StatusCodes.Status406NotAcceptable: + return _bytesStatus406; + case StatusCodes.Status407ProxyAuthenticationRequired: + return _bytesStatus407; + case StatusCodes.Status408RequestTimeout: + return _bytesStatus408; + case StatusCodes.Status409Conflict: + return _bytesStatus409; + case StatusCodes.Status410Gone: + return _bytesStatus410; + case StatusCodes.Status411LengthRequired: + return _bytesStatus411; + case StatusCodes.Status412PreconditionFailed: + return _bytesStatus412; + case StatusCodes.Status413PayloadTooLarge: + return _bytesStatus413; + case StatusCodes.Status414UriTooLong: + return _bytesStatus414; + case StatusCodes.Status415UnsupportedMediaType: + return _bytesStatus415; + case StatusCodes.Status416RangeNotSatisfiable: + return _bytesStatus416; + case StatusCodes.Status417ExpectationFailed: + return _bytesStatus417; + case StatusCodes.Status418ImATeapot: + return _bytesStatus418; + case StatusCodes.Status419AuthenticationTimeout: + return _bytesStatus419; + case StatusCodes.Status421MisdirectedRequest: + return _bytesStatus421; + case StatusCodes.Status422UnprocessableEntity: + return _bytesStatus422; + case StatusCodes.Status423Locked: + return _bytesStatus423; + case StatusCodes.Status424FailedDependency: + return _bytesStatus424; + case StatusCodes.Status426UpgradeRequired: + return _bytesStatus426; + case StatusCodes.Status428PreconditionRequired: + return _bytesStatus428; + case StatusCodes.Status429TooManyRequests: + return _bytesStatus429; + case StatusCodes.Status431RequestHeaderFieldsTooLarge: + return _bytesStatus431; + case StatusCodes.Status451UnavailableForLegalReasons: + return _bytesStatus451; + + case StatusCodes.Status500InternalServerError: + return _bytesStatus500; + case StatusCodes.Status501NotImplemented: + return _bytesStatus501; + case StatusCodes.Status502BadGateway: + return _bytesStatus502; + case StatusCodes.Status503ServiceUnavailable: + return _bytesStatus503; + case StatusCodes.Status504GatewayTimeout: + return _bytesStatus504; + case StatusCodes.Status505HttpVersionNotsupported: + return _bytesStatus505; + case StatusCodes.Status506VariantAlsoNegotiates: + return _bytesStatus506; + case StatusCodes.Status507InsufficientStorage: + return _bytesStatus507; + case StatusCodes.Status508LoopDetected: + return _bytesStatus508; + case StatusCodes.Status510NotExtended: + return _bytesStatus510; + case StatusCodes.Status511NetworkAuthenticationRequired: + return _bytesStatus511; + + default: + var predefinedReasonPhrase = WebUtilities.ReasonPhrases.GetReasonPhrase(statusCode); + // https://tools.ietf.org/html/rfc7230#section-3.1.2 requires trailing whitespace regardless of reason phrase + var formattedStatusCode = statusCode.ToString(CultureInfo.InvariantCulture) + " "; + return string.IsNullOrEmpty(predefinedReasonPhrase) + ? Encoding.ASCII.GetBytes(formattedStatusCode) + : Encoding.ASCII.GetBytes(formattedStatusCode + predefinedReasonPhrase); + + } + } + return Encoding.ASCII.GetBytes(statusCode.ToString(CultureInfo.InvariantCulture) + " " + reasonPhrase); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs b/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs new file mode 100644 index 0000000000..f6e4248047 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public enum RequestProcessingStatus + { + RequestPending, + ParsingRequestLine, + ParsingHeaders, + AppStarted, + ResponseStarted + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/RequestRejectionReason.cs b/src/Servers/Kestrel/Core/src/Internal/Http/RequestRejectionReason.cs new file mode 100644 index 0000000000..ee27b5cb96 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/RequestRejectionReason.cs @@ -0,0 +1,38 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public enum RequestRejectionReason + { + UnrecognizedHTTPVersion, + InvalidRequestLine, + InvalidRequestHeader, + InvalidRequestHeadersNoCRLF, + MalformedRequestInvalidHeaders, + InvalidContentLength, + MultipleContentLengths, + UnexpectedEndOfRequestContent, + BadChunkSuffix, + BadChunkSizeData, + ChunkedRequestIncomplete, + InvalidRequestTarget, + InvalidCharactersInHeaderName, + RequestLineTooLong, + HeadersExceedMaxTotalSize, + TooManyHeaders, + RequestBodyTooLarge, + RequestHeadersTimeout, + RequestBodyTimeout, + FinalTransferCodingNotChunked, + LengthRequired, + LengthRequiredHttp10, + OptionsMethodRequired, + ConnectMethodRequired, + MissingHostHeader, + MultipleHostHeaders, + InvalidHostHeader, + UpgradeRequestCannotHavePayload, + RequestBodyExceedsContentLength + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/TransferCoding.cs b/src/Servers/Kestrel/Core/src/Internal/Http/TransferCoding.cs new file mode 100644 index 0000000000..39c52ba6aa --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/TransferCoding.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + [Flags] + public enum TransferCoding + { + None, + Chunked, + Other + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/UrlDecoder.cs b/src/Servers/Kestrel/Core/src/Internal/Http/UrlDecoder.cs new file mode 100644 index 0000000000..e5ed2c86fa --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http/UrlDecoder.cs @@ -0,0 +1,411 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.AspNetCore.Connections.Abstractions +{ + internal class UrlDecoder + { + static bool[] IsAllowed = new bool[0x7F + 1]; + + /// + /// Unescape a URL path + /// + /// The byte span represents a UTF8 encoding url path. + /// The byte span where unescaped url path is copied to. + /// The length of the byte sequence of the unescaped url path. + public static int Decode(ReadOnlySpan source, Span destination) + { + if (destination.Length < source.Length) + { + throw new ArgumentException( + "Lenghth of the destination byte span is less then the source.", + nameof(destination)); + } + + // This requires the destination span to be larger or equal to source span + source.CopyTo(destination); + return DecodeInPlace(destination); + } + + /// + /// Unescape a URL path in place. + /// + /// The byte span represents a UTF8 encoding url path. + /// The number of the bytes representing the result. + /// + /// The unescape is done in place, which means after decoding the result is the subset of + /// the input span. + /// + public static int DecodeInPlace(Span buffer) + { + // the slot to read the input + var sourceIndex = 0; + + // the slot to write the unescaped byte + var destinationIndex = 0; + + while (true) + { + if (sourceIndex == buffer.Length) + { + break; + } + + if (buffer[sourceIndex] == '%') + { + var decodeIndex = sourceIndex; + + // If decoding process succeeds, the writer iterator will be moved + // to the next write-ready location. On the other hand if the scanned + // percent-encodings cannot be interpreted as sequence of UTF-8 octets, + // these bytes should be copied to output as is. + // The decodeReader iterator is always moved to the first byte not yet + // be scanned after the process. A failed decoding means the chars + // between the reader and decodeReader can be copied to output untouched. + if (!DecodeCore(ref decodeIndex, ref destinationIndex, buffer)) + { + Copy(sourceIndex, decodeIndex, ref destinationIndex, buffer); + } + + sourceIndex = decodeIndex; + } + else + { + buffer[destinationIndex++] = buffer[sourceIndex++]; + } + } + + return destinationIndex; + } + + /// + /// Unescape the percent-encodings + /// + /// The iterator point to the first % char + /// The place to write to + /// The byte array + private static bool DecodeCore(ref int sourceIndex, ref int destinationIndex, Span buffer) + { + // preserves the original head. if the percent-encodings cannot be interpreted as sequence of UTF-8 octets, + // bytes from this till the last scanned one will be copied to the memory pointed by writer. + var byte1 = UnescapePercentEncoding(ref sourceIndex, buffer); + if (byte1 == -1) + { + return false; + } + + if (byte1 == 0) + { + throw new InvalidOperationException("The path contains null characters."); + } + + if (byte1 <= 0x7F) + { + // first byte < U+007f, it is a single byte ASCII + buffer[destinationIndex++] = (byte)byte1; + return true; + } + + int byte2 = 0, byte3 = 0, byte4 = 0; + + // anticipate more bytes + var currentDecodeBits = 0; + var byteCount = 1; + var expectValueMin = 0; + if ((byte1 & 0xE0) == 0xC0) + { + // 110x xxxx, expect one more byte + currentDecodeBits = byte1 & 0x1F; + byteCount = 2; + expectValueMin = 0x80; + } + else if ((byte1 & 0xF0) == 0xE0) + { + // 1110 xxxx, expect two more bytes + currentDecodeBits = byte1 & 0x0F; + byteCount = 3; + expectValueMin = 0x800; + } + else if ((byte1 & 0xF8) == 0xF0) + { + // 1111 0xxx, expect three more bytes + currentDecodeBits = byte1 & 0x07; + byteCount = 4; + expectValueMin = 0x10000; + } + else + { + // invalid first byte + return false; + } + + var remainingBytes = byteCount - 1; + while (remainingBytes > 0) + { + // read following three chars + if (sourceIndex == buffer.Length) + { + return false; + } + + var nextSourceIndex = sourceIndex; + var nextByte = UnescapePercentEncoding(ref nextSourceIndex, buffer); + if (nextByte == -1) + { + return false; + } + + if ((nextByte & 0xC0) != 0x80) + { + // the follow up byte is not in form of 10xx xxxx + return false; + } + + currentDecodeBits = (currentDecodeBits << 6) | (nextByte & 0x3F); + remainingBytes--; + + if (remainingBytes == 1 && currentDecodeBits >= 0x360 && currentDecodeBits <= 0x37F) + { + // this is going to end up in the range of 0xD800-0xDFFF UTF-16 surrogates that + // are not allowed in UTF-8; + return false; + } + + if (remainingBytes == 2 && currentDecodeBits >= 0x110) + { + // this is going to be out of the upper Unicode bound 0x10FFFF. + return false; + } + + sourceIndex = nextSourceIndex; + if (byteCount - remainingBytes == 2) + { + byte2 = nextByte; + } + else if (byteCount - remainingBytes == 3) + { + byte3 = nextByte; + } + else if (byteCount - remainingBytes == 4) + { + byte4 = nextByte; + } + } + + if (currentDecodeBits < expectValueMin) + { + // overlong encoding (e.g. using 2 bytes to encode something that only needed 1). + return false; + } + + // all bytes are verified, write to the output + // TODO: measure later to determine if the performance of following logic can be improved + // the idea is to combine the bytes into short/int and write to span directly to avoid + // range check cost + if (byteCount > 0) + { + buffer[destinationIndex++] = (byte)byte1; + } + if (byteCount > 1) + { + buffer[destinationIndex++] = (byte)byte2; + } + if (byteCount > 2) + { + buffer[destinationIndex++] = (byte)byte3; + } + if (byteCount > 3) + { + buffer[destinationIndex++] = (byte)byte4; + } + + return true; + } + + private static void Copy(int begin, int end, ref int writer, Span buffer) + { + while (begin != end) + { + buffer[writer++] = buffer[begin++]; + } + } + + /// + /// Read the percent-encoding and try unescape it. + /// + /// The operation first peek at the character the + /// iterator points at. If it is % the is then + /// moved on to scan the following to characters. If the two following + /// characters are hexadecimal literals they will be unescaped and the + /// value will be returned. + /// + /// If the first character is not % the iterator + /// will be removed beyond the location of % and -1 will be returned. + /// + /// If the following two characters can't be successfully unescaped the + /// iterator will be move behind the % and -1 + /// will be returned. + /// + /// The value to read + /// The byte array + /// The unescaped byte if success. Otherwise return -1. + private static int UnescapePercentEncoding(ref int scan, Span buffer) + { + if (buffer[scan++] != '%') + { + return -1; + } + + var probe = scan; + + var value1 = ReadHex(ref probe, buffer); + if (value1 == -1) + { + return -1; + } + + var value2 = ReadHex(ref probe, buffer); + if (value2 == -1) + { + return -1; + } + + if (SkipUnescape(value1, value2)) + { + return -1; + } + + scan = probe; + return (value1 << 4) + value2; + } + + + /// + /// Read the next char and convert it into hexadecimal value. + /// + /// The index will be moved to the next + /// byte no matter no matter whether the operation successes. + /// + /// The index of the byte in the buffer to read + /// The byte span from which the hex to be read + /// The hexadecimal value if successes, otherwise -1. + private static int ReadHex(ref int scan, Span buffer) + { + if (scan == buffer.Length) + { + return -1; + } + + var value = buffer[scan++]; + var isHead = ((value >= '0') && (value <= '9')) || + ((value >= 'A') && (value <= 'F')) || + ((value >= 'a') && (value <= 'f')); + + if (!isHead) + { + return -1; + } + + if (value <= '9') + { + return value - '0'; + } + else if (value <= 'F') + { + return (value - 'A') + 10; + } + else // a - f + { + return (value - 'a') + 10; + } + } + + private static bool SkipUnescape(int value1, int value2) + { + // skip %2F - '/' + if (value1 == 2 && value2 == 15) + { + return true; + } + + return false; + } + + static UrlDecoder() + { + // Unreserved + IsAllowed['A'] = true; + IsAllowed['B'] = true; + IsAllowed['C'] = true; + IsAllowed['D'] = true; + IsAllowed['E'] = true; + IsAllowed['F'] = true; + IsAllowed['G'] = true; + IsAllowed['H'] = true; + IsAllowed['I'] = true; + IsAllowed['J'] = true; + IsAllowed['K'] = true; + IsAllowed['L'] = true; + IsAllowed['M'] = true; + IsAllowed['N'] = true; + IsAllowed['O'] = true; + IsAllowed['P'] = true; + IsAllowed['Q'] = true; + IsAllowed['R'] = true; + IsAllowed['S'] = true; + IsAllowed['T'] = true; + IsAllowed['U'] = true; + IsAllowed['V'] = true; + IsAllowed['W'] = true; + IsAllowed['X'] = true; + IsAllowed['Y'] = true; + IsAllowed['Z'] = true; + + IsAllowed['a'] = true; + IsAllowed['b'] = true; + IsAllowed['c'] = true; + IsAllowed['d'] = true; + IsAllowed['e'] = true; + IsAllowed['f'] = true; + IsAllowed['g'] = true; + IsAllowed['h'] = true; + IsAllowed['i'] = true; + IsAllowed['j'] = true; + IsAllowed['k'] = true; + IsAllowed['l'] = true; + IsAllowed['m'] = true; + IsAllowed['n'] = true; + IsAllowed['o'] = true; + IsAllowed['p'] = true; + IsAllowed['q'] = true; + IsAllowed['r'] = true; + IsAllowed['s'] = true; + IsAllowed['t'] = true; + IsAllowed['u'] = true; + IsAllowed['v'] = true; + IsAllowed['w'] = true; + IsAllowed['x'] = true; + IsAllowed['y'] = true; + IsAllowed['z'] = true; + + IsAllowed['0'] = true; + IsAllowed['1'] = true; + IsAllowed['2'] = true; + IsAllowed['3'] = true; + IsAllowed['4'] = true; + IsAllowed['5'] = true; + IsAllowed['6'] = true; + IsAllowed['7'] = true; + IsAllowed['8'] = true; + IsAllowed['9'] = true; + + IsAllowed['-'] = true; + IsAllowed['_'] = true; + IsAllowed['.'] = true; + IsAllowed['~'] = true; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/DynamicTable.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/DynamicTable.cs new file mode 100644 index 0000000000..6a13e49b82 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/DynamicTable.cs @@ -0,0 +1,94 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class DynamicTable + { + private HeaderField[] _buffer; + private int _maxSize; + private int _size; + private int _count; + private int _insertIndex; + private int _removeIndex; + + public DynamicTable(int maxSize) + { + _buffer = new HeaderField[maxSize / HeaderField.RfcOverhead]; + _maxSize = maxSize; + } + + public int Count => _count; + + public int Size => _size; + + public int MaxSize => _maxSize; + + public HeaderField this[int index] + { + get + { + if (index >= _count) + { + throw new IndexOutOfRangeException(); + } + + return _buffer[_insertIndex == 0 ? _buffer.Length - 1 : _insertIndex - index - 1]; + } + } + + public void Insert(Span name, Span value) + { + var entryLength = HeaderField.GetLength(name.Length, value.Length); + EnsureAvailable(entryLength); + + if (entryLength > _maxSize) + { + // http://httpwg.org/specs/rfc7541.html#rfc.section.4.4 + // It is not an error to attempt to add an entry that is larger than the maximum size; + // an attempt to add an entry larger than the maximum size causes the table to be emptied + // of all existing entries and results in an empty table. + return; + } + + var entry = new HeaderField(name, value); + _buffer[_insertIndex] = entry; + _insertIndex = (_insertIndex + 1) % _buffer.Length; + _size += entry.Length; + _count++; + } + + public void Resize(int maxSize) + { + if (maxSize > _maxSize) + { + var newBuffer = new HeaderField[maxSize / HeaderField.RfcOverhead]; + + for (var i = 0; i < Count; i++) + { + newBuffer[i] = _buffer[i]; + } + + _buffer = newBuffer; + _maxSize = maxSize; + } + else + { + _maxSize = maxSize; + EnsureAvailable(0); + } + } + + private void EnsureAvailable(int available) + { + while (_count > 0 && _maxSize - _size < available) + { + _size -= _buffer[_removeIndex].Length; + _count--; + _removeIndex = (_removeIndex + 1) % _buffer.Length; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HPackDecoder.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HPackDecoder.cs new file mode 100644 index 0000000000..84f91ca896 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HPackDecoder.cs @@ -0,0 +1,406 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class HPackDecoder + { + private enum State + { + Ready, + HeaderFieldIndex, + HeaderNameIndex, + HeaderNameLength, + HeaderNameLengthContinue, + HeaderName, + HeaderValueLength, + HeaderValueLengthContinue, + HeaderValue, + DynamicTableSizeUpdate + } + + // TODO: add new configurable limit + public const int MaxStringOctets = 4096; + + // http://httpwg.org/specs/rfc7541.html#rfc.section.6.1 + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 1 | Index (7+) | + // +---+---------------------------+ + private const byte IndexedHeaderFieldMask = 0x80; + private const byte IndexedHeaderFieldRepresentation = 0x80; + + // http://httpwg.org/specs/rfc7541.html#rfc.section.6.2.1 + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 1 | Index (6+) | + // +---+---+-----------------------+ + private const byte LiteralHeaderFieldWithIncrementalIndexingMask = 0xc0; + private const byte LiteralHeaderFieldWithIncrementalIndexingRepresentation = 0x40; + + // http://httpwg.org/specs/rfc7541.html#rfc.section.6.2.2 + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | 0 | 0 | Index (4+) | + // +---+---+-----------------------+ + private const byte LiteralHeaderFieldWithoutIndexingMask = 0xf0; + private const byte LiteralHeaderFieldWithoutIndexingRepresentation = 0x00; + + // http://httpwg.org/specs/rfc7541.html#rfc.section.6.2.3 + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | 0 | 1 | Index (4+) | + // +---+---+-----------------------+ + private const byte LiteralHeaderFieldNeverIndexedMask = 0xf0; + private const byte LiteralHeaderFieldNeverIndexedRepresentation = 0x10; + + // http://httpwg.org/specs/rfc7541.html#rfc.section.6.3 + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | 1 | Max size (5+) | + // +---+---------------------------+ + private const byte DynamicTableSizeUpdateMask = 0xe0; + private const byte DynamicTableSizeUpdateRepresentation = 0x20; + + // http://httpwg.org/specs/rfc7541.html#rfc.section.5.2 + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | H | String Length (7+) | + // +---+---------------------------+ + private const byte HuffmanMask = 0x80; + + private const int IndexedHeaderFieldPrefix = 7; + private const int LiteralHeaderFieldWithIncrementalIndexingPrefix = 6; + private const int LiteralHeaderFieldWithoutIndexingPrefix = 4; + private const int LiteralHeaderFieldNeverIndexedPrefix = 4; + private const int DynamicTableSizeUpdatePrefix = 5; + private const int StringLengthPrefix = 7; + + private readonly int _maxDynamicTableSize; + private readonly DynamicTable _dynamicTable; + private readonly IntegerDecoder _integerDecoder = new IntegerDecoder(); + private readonly byte[] _stringOctets = new byte[MaxStringOctets]; + private readonly byte[] _headerNameOctets = new byte[MaxStringOctets]; + private readonly byte[] _headerValueOctets = new byte[MaxStringOctets]; + + private State _state = State.Ready; + private byte[] _headerName; + private int _stringIndex; + private int _stringLength; + private int _headerNameLength; + private int _headerValueLength; + private bool _index; + private bool _huffman; + + public HPackDecoder(int maxDynamicTableSize) + : this(maxDynamicTableSize, new DynamicTable(maxDynamicTableSize)) + { + _maxDynamicTableSize = maxDynamicTableSize; + } + + // For testing. + internal HPackDecoder(int maxDynamicTableSize, DynamicTable dynamicTable) + { + _maxDynamicTableSize = maxDynamicTableSize; + _dynamicTable = dynamicTable; + } + + public void Decode(Span data, bool endHeaders, IHttpHeadersHandler handler) + { + for (var i = 0; i < data.Length; i++) + { + OnByte(data[i], handler); + } + + if (endHeaders && _state != State.Ready) + { + throw new HPackDecodingException(CoreStrings.HPackErrorIncompleteHeaderBlock); + } + } + + public void OnByte(byte b, IHttpHeadersHandler handler) + { + switch (_state) + { + case State.Ready: + if ((b & IndexedHeaderFieldMask) == IndexedHeaderFieldRepresentation) + { + var val = b & ~IndexedHeaderFieldMask; + + if (_integerDecoder.BeginDecode((byte)val, IndexedHeaderFieldPrefix)) + { + OnIndexedHeaderField(_integerDecoder.Value, handler); + } + else + { + _state = State.HeaderFieldIndex; + } + } + else if ((b & LiteralHeaderFieldWithIncrementalIndexingMask) == LiteralHeaderFieldWithIncrementalIndexingRepresentation) + { + _index = true; + var val = b & ~LiteralHeaderFieldWithIncrementalIndexingMask; + + if (val == 0) + { + _state = State.HeaderNameLength; + } + else if (_integerDecoder.BeginDecode((byte)val, LiteralHeaderFieldWithIncrementalIndexingPrefix)) + { + OnIndexedHeaderName(_integerDecoder.Value); + } + else + { + _state = State.HeaderNameIndex; + } + } + else if ((b & LiteralHeaderFieldWithoutIndexingMask) == LiteralHeaderFieldWithoutIndexingRepresentation) + { + _index = false; + var val = b & ~LiteralHeaderFieldWithoutIndexingMask; + + if (val == 0) + { + _state = State.HeaderNameLength; + } + else if (_integerDecoder.BeginDecode((byte)val, LiteralHeaderFieldWithoutIndexingPrefix)) + { + OnIndexedHeaderName(_integerDecoder.Value); + } + else + { + _state = State.HeaderNameIndex; + } + } + else if ((b & LiteralHeaderFieldNeverIndexedMask) == LiteralHeaderFieldNeverIndexedRepresentation) + { + _index = false; + var val = b & ~LiteralHeaderFieldNeverIndexedMask; + + if (val == 0) + { + _state = State.HeaderNameLength; + } + else if (_integerDecoder.BeginDecode((byte)val, LiteralHeaderFieldNeverIndexedPrefix)) + { + OnIndexedHeaderName(_integerDecoder.Value); + } + else + { + _state = State.HeaderNameIndex; + } + } + else if ((b & DynamicTableSizeUpdateMask) == DynamicTableSizeUpdateRepresentation) + { + if (_integerDecoder.BeginDecode((byte)(b & ~DynamicTableSizeUpdateMask), DynamicTableSizeUpdatePrefix)) + { + // TODO: validate that it's less than what's defined via SETTINGS + _dynamicTable.Resize(_integerDecoder.Value); + } + else + { + _state = State.DynamicTableSizeUpdate; + } + } + else + { + // Can't happen + throw new HPackDecodingException($"Byte value {b} does not encode a valid header field representation."); + } + + break; + case State.HeaderFieldIndex: + if (_integerDecoder.Decode(b)) + { + OnIndexedHeaderField(_integerDecoder.Value, handler); + } + + break; + case State.HeaderNameIndex: + if (_integerDecoder.Decode(b)) + { + OnIndexedHeaderName(_integerDecoder.Value); + } + + break; + case State.HeaderNameLength: + _huffman = (b & HuffmanMask) != 0; + + if (_integerDecoder.BeginDecode((byte)(b & ~HuffmanMask), StringLengthPrefix)) + { + OnStringLength(_integerDecoder.Value, nextState: State.HeaderName); + } + else + { + _state = State.HeaderNameLengthContinue; + } + + break; + case State.HeaderNameLengthContinue: + if (_integerDecoder.Decode(b)) + { + OnStringLength(_integerDecoder.Value, nextState: State.HeaderName); + } + + break; + case State.HeaderName: + _stringOctets[_stringIndex++] = b; + + if (_stringIndex == _stringLength) + { + OnString(nextState: State.HeaderValueLength); + } + + break; + case State.HeaderValueLength: + _huffman = (b & HuffmanMask) != 0; + + if (_integerDecoder.BeginDecode((byte)(b & ~HuffmanMask), StringLengthPrefix)) + { + OnStringLength(_integerDecoder.Value, nextState: State.HeaderValue); + if (_integerDecoder.Value == 0) + { + ProcessHeaderValue(handler); + } + } + else + { + _state = State.HeaderValueLengthContinue; + } + + break; + case State.HeaderValueLengthContinue: + if (_integerDecoder.Decode(b)) + { + OnStringLength(_integerDecoder.Value, nextState: State.HeaderValue); + if (_integerDecoder.Value == 0) + { + ProcessHeaderValue(handler); + } + } + + break; + case State.HeaderValue: + _stringOctets[_stringIndex++] = b; + + if (_stringIndex == _stringLength) + { + ProcessHeaderValue(handler); + } + + break; + case State.DynamicTableSizeUpdate: + if (_integerDecoder.Decode(b)) + { + if (_integerDecoder.Value > _maxDynamicTableSize) + { + throw new HPackDecodingException( + CoreStrings.FormatHPackErrorDynamicTableSizeUpdateTooLarge(_integerDecoder.Value, _maxDynamicTableSize)); + } + + _dynamicTable.Resize(_integerDecoder.Value); + _state = State.Ready; + } + + break; + default: + // Can't happen + throw new HPackDecodingException("The HPACK decoder reached an invalid state."); + } + } + + private void ProcessHeaderValue(IHttpHeadersHandler handler) + { + OnString(nextState: State.Ready); + + var headerNameSpan = new Span(_headerName, 0, _headerNameLength); + var headerValueSpan = new Span(_headerValueOctets, 0, _headerValueLength); + + handler.OnHeader(headerNameSpan, headerValueSpan); + + if (_index) + { + _dynamicTable.Insert(headerNameSpan, headerValueSpan); + } + } + + private void OnIndexedHeaderField(int index, IHttpHeadersHandler handler) + { + var header = GetHeader(index); + handler.OnHeader(new Span(header.Name), new Span(header.Value)); + _state = State.Ready; + } + + private void OnIndexedHeaderName(int index) + { + var header = GetHeader(index); + _headerName = header.Name; + _headerNameLength = header.Name.Length; + _state = State.HeaderValueLength; + } + + private void OnStringLength(int length, State nextState) + { + if (length > _stringOctets.Length) + { + throw new HPackDecodingException(CoreStrings.FormatHPackStringLengthTooLarge(length, _stringOctets.Length)); + } + + _stringLength = length; + _stringIndex = 0; + _state = nextState; + } + + private void OnString(State nextState) + { + int Decode(byte[] dst) + { + if (_huffman) + { + return Huffman.Decode(_stringOctets, 0, _stringLength, dst); + } + else + { + Buffer.BlockCopy(_stringOctets, 0, dst, 0, _stringLength); + return _stringLength; + } + } + + try + { + if (_state == State.HeaderName) + { + _headerName = _headerNameOctets; + _headerNameLength = Decode(_headerNameOctets); + } + else + { + _headerValueLength = Decode(_headerValueOctets); + } + } + catch (HuffmanDecodingException ex) + { + throw new HPackDecodingException(CoreStrings.HPackHuffmanError, ex); + } + + _state = nextState; + } + + private HeaderField GetHeader(int index) + { + try + { + return index <= StaticTable.Instance.Count + ? StaticTable.Instance[index - 1] + : _dynamicTable[index - StaticTable.Instance.Count - 1]; + } + catch (IndexOutOfRangeException ex) + { + throw new HPackDecodingException(CoreStrings.FormatHPackErrorIndexOutOfRange(index), ex); + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HPackDecodingException.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HPackDecodingException.cs new file mode 100644 index 0000000000..7ae0ddddf5 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HPackDecodingException.cs @@ -0,0 +1,19 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class HPackDecodingException : Exception + { + public HPackDecodingException(string message) + : base(message) + { + } + public HPackDecodingException(string message, Exception innerException) + : base(message, innerException) + { + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HPackEncoder.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HPackEncoder.cs new file mode 100644 index 0000000000..0c92961acf --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HPackEncoder.cs @@ -0,0 +1,151 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class HPackEncoder + { + private IEnumerator> _enumerator; + + public bool BeginEncode(IEnumerable> headers, Span buffer, out int length) + { + _enumerator = headers.GetEnumerator(); + _enumerator.MoveNext(); + + return Encode(buffer, out length); + } + + public bool BeginEncode(int statusCode, IEnumerable> headers, Span buffer, out int length) + { + _enumerator = headers.GetEnumerator(); + _enumerator.MoveNext(); + + var statusCodeLength = EncodeStatusCode(statusCode, buffer); + var done = Encode(buffer.Slice(statusCodeLength), out var headersLength); + length = statusCodeLength + headersLength; + + return done; + } + + public bool Encode(Span buffer, out int length) + { + length = 0; + + do + { + if (!EncodeHeader(_enumerator.Current.Key, _enumerator.Current.Value, buffer.Slice(length), out var headerLength)) + { + return false; + } + + length += headerLength; + } while (_enumerator.MoveNext()); + + return true; + } + + private int EncodeStatusCode(int statusCode, Span buffer) + { + switch (statusCode) + { + case 200: + case 204: + case 206: + case 304: + case 400: + case 404: + case 500: + buffer[0] = (byte)(0x80 | StaticTable.Instance.StatusIndex[statusCode]); + return 1; + default: + // Send as Literal Header Field Without Indexing - Indexed Name + buffer[0] = 0x08; + + var statusBytes = StatusCodes.ToStatusBytes(statusCode); + buffer[1] = (byte)statusBytes.Length; + ((Span)statusBytes).CopyTo(buffer.Slice(2)); + + return 2 + statusBytes.Length; + } + } + + private bool EncodeHeader(string name, string value, Span buffer, out int length) + { + var i = 0; + length = 0; + + if (buffer.Length == 0) + { + return false; + } + + buffer[i++] = 0; + + if (i == buffer.Length) + { + return false; + } + + if (!EncodeString(name, buffer.Slice(i), out var nameLength, lowercase: true)) + { + return false; + } + + i += nameLength; + + if (i >= buffer.Length) + { + return false; + } + + if (!EncodeString(value, buffer.Slice(i), out var valueLength, lowercase: false)) + { + return false; + } + + i += valueLength; + + length = i; + return true; + } + + private bool EncodeString(string s, Span buffer, out int length, bool lowercase) + { + const int toLowerMask = 0x20; + + var i = 0; + length = 0; + + if (buffer.Length == 0) + { + return false; + } + + buffer[0] = 0; + + if (!IntegerEncoder.Encode(s.Length, 7, buffer, out var nameLength)) + { + return false; + } + + i += nameLength; + + // TODO: use huffman encoding + for (var j = 0; j < s.Length; j++) + { + if (i >= buffer.Length) + { + return false; + } + + buffer[i++] = (byte)(s[j] | (lowercase ? toLowerMask : 0)); + } + + length = i; + return true; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HeaderField.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HeaderField.cs new file mode 100644 index 0000000000..73eb4d726e --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HeaderField.cs @@ -0,0 +1,30 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public struct HeaderField + { + // http://httpwg.org/specs/rfc7541.html#rfc.section.4.1 + public const int RfcOverhead = 32; + + public HeaderField(Span name, Span value) + { + Name = new byte[name.Length]; + name.CopyTo(Name); + + Value = new byte[value.Length]; + value.CopyTo(Value); + } + + public byte[] Name { get; } + + public byte[] Value { get; } + + public int Length => GetLength(Name.Length, Value.Length); + + public static int GetLength(int nameLength, int valueLenth) => nameLength + valueLenth + 32; + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/Huffman.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/Huffman.cs new file mode 100644 index 0000000000..f0d489c952 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/Huffman.cs @@ -0,0 +1,426 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class Huffman + { + // TODO: this can be constructed from _decodingTable + private static readonly (uint code, int bitLength)[] _encodingTable = new (uint code, int bitLength)[] + { + (0b11111111_11000000_00000000_00000000, 13), + (0b11111111_11111111_10110000_00000000, 23), + (0b11111111_11111111_11111110_00100000, 28), + (0b11111111_11111111_11111110_00110000, 28), + (0b11111111_11111111_11111110_01000000, 28), + (0b11111111_11111111_11111110_01010000, 28), + (0b11111111_11111111_11111110_01100000, 28), + (0b11111111_11111111_11111110_01110000, 28), + (0b11111111_11111111_11111110_10000000, 28), + (0b11111111_11111111_11101010_00000000, 24), + (0b11111111_11111111_11111111_11110000, 30), + (0b11111111_11111111_11111110_10010000, 28), + (0b11111111_11111111_11111110_10100000, 28), + (0b11111111_11111111_11111111_11110100, 30), + (0b11111111_11111111_11111110_10110000, 28), + (0b11111111_11111111_11111110_11000000, 28), + (0b11111111_11111111_11111110_11010000, 28), + (0b11111111_11111111_11111110_11100000, 28), + (0b11111111_11111111_11111110_11110000, 28), + (0b11111111_11111111_11111111_00000000, 28), + (0b11111111_11111111_11111111_00010000, 28), + (0b11111111_11111111_11111111_00100000, 28), + (0b11111111_11111111_11111111_11111000, 30), + (0b11111111_11111111_11111111_00110000, 28), + (0b11111111_11111111_11111111_01000000, 28), + (0b11111111_11111111_11111111_01010000, 28), + (0b11111111_11111111_11111111_01100000, 28), + (0b11111111_11111111_11111111_01110000, 28), + (0b11111111_11111111_11111111_10000000, 28), + (0b11111111_11111111_11111111_10010000, 28), + (0b11111111_11111111_11111111_10100000, 28), + (0b11111111_11111111_11111111_10110000, 28), + (0b01010000_00000000_00000000_00000000, 6), + (0b11111110_00000000_00000000_00000000, 10), + (0b11111110_01000000_00000000_00000000, 10), + (0b11111111_10100000_00000000_00000000, 12), + (0b11111111_11001000_00000000_00000000, 13), + (0b01010100_00000000_00000000_00000000, 6), + (0b11111000_00000000_00000000_00000000, 8), + (0b11111111_01000000_00000000_00000000, 11), + (0b11111110_10000000_00000000_00000000, 10), + (0b11111110_11000000_00000000_00000000, 10), + (0b11111001_00000000_00000000_00000000, 8), + (0b11111111_01100000_00000000_00000000, 11), + (0b11111010_00000000_00000000_00000000, 8), + (0b01011000_00000000_00000000_00000000, 6), + (0b01011100_00000000_00000000_00000000, 6), + (0b01100000_00000000_00000000_00000000, 6), + (0b00000000_00000000_00000000_00000000, 5), + (0b00001000_00000000_00000000_00000000, 5), + (0b00010000_00000000_00000000_00000000, 5), + (0b01100100_00000000_00000000_00000000, 6), + (0b01101000_00000000_00000000_00000000, 6), + (0b01101100_00000000_00000000_00000000, 6), + (0b01110000_00000000_00000000_00000000, 6), + (0b01110100_00000000_00000000_00000000, 6), + (0b01111000_00000000_00000000_00000000, 6), + (0b01111100_00000000_00000000_00000000, 6), + (0b10111000_00000000_00000000_00000000, 7), + (0b11111011_00000000_00000000_00000000, 8), + (0b11111111_11111000_00000000_00000000, 15), + (0b10000000_00000000_00000000_00000000, 6), + (0b11111111_10110000_00000000_00000000, 12), + (0b11111111_00000000_00000000_00000000, 10), + (0b11111111_11010000_00000000_00000000, 13), + (0b10000100_00000000_00000000_00000000, 6), + (0b10111010_00000000_00000000_00000000, 7), + (0b10111100_00000000_00000000_00000000, 7), + (0b10111110_00000000_00000000_00000000, 7), + (0b11000000_00000000_00000000_00000000, 7), + (0b11000010_00000000_00000000_00000000, 7), + (0b11000100_00000000_00000000_00000000, 7), + (0b11000110_00000000_00000000_00000000, 7), + (0b11001000_00000000_00000000_00000000, 7), + (0b11001010_00000000_00000000_00000000, 7), + (0b11001100_00000000_00000000_00000000, 7), + (0b11001110_00000000_00000000_00000000, 7), + (0b11010000_00000000_00000000_00000000, 7), + (0b11010010_00000000_00000000_00000000, 7), + (0b11010100_00000000_00000000_00000000, 7), + (0b11010110_00000000_00000000_00000000, 7), + (0b11011000_00000000_00000000_00000000, 7), + (0b11011010_00000000_00000000_00000000, 7), + (0b11011100_00000000_00000000_00000000, 7), + (0b11011110_00000000_00000000_00000000, 7), + (0b11100000_00000000_00000000_00000000, 7), + (0b11100010_00000000_00000000_00000000, 7), + (0b11100100_00000000_00000000_00000000, 7), + (0b11111100_00000000_00000000_00000000, 8), + (0b11100110_00000000_00000000_00000000, 7), + (0b11111101_00000000_00000000_00000000, 8), + (0b11111111_11011000_00000000_00000000, 13), + (0b11111111_11111110_00000000_00000000, 19), + (0b11111111_11100000_00000000_00000000, 13), + (0b11111111_11110000_00000000_00000000, 14), + (0b10001000_00000000_00000000_00000000, 6), + (0b11111111_11111010_00000000_00000000, 15), + (0b00011000_00000000_00000000_00000000, 5), + (0b10001100_00000000_00000000_00000000, 6), + (0b00100000_00000000_00000000_00000000, 5), + (0b10010000_00000000_00000000_00000000, 6), + (0b00101000_00000000_00000000_00000000, 5), + (0b10010100_00000000_00000000_00000000, 6), + (0b10011000_00000000_00000000_00000000, 6), + (0b10011100_00000000_00000000_00000000, 6), + (0b00110000_00000000_00000000_00000000, 5), + (0b11101000_00000000_00000000_00000000, 7), + (0b11101010_00000000_00000000_00000000, 7), + (0b10100000_00000000_00000000_00000000, 6), + (0b10100100_00000000_00000000_00000000, 6), + (0b10101000_00000000_00000000_00000000, 6), + (0b00111000_00000000_00000000_00000000, 5), + (0b10101100_00000000_00000000_00000000, 6), + (0b11101100_00000000_00000000_00000000, 7), + (0b10110000_00000000_00000000_00000000, 6), + (0b01000000_00000000_00000000_00000000, 5), + (0b01001000_00000000_00000000_00000000, 5), + (0b10110100_00000000_00000000_00000000, 6), + (0b11101110_00000000_00000000_00000000, 7), + (0b11110000_00000000_00000000_00000000, 7), + (0b11110010_00000000_00000000_00000000, 7), + (0b11110100_00000000_00000000_00000000, 7), + (0b11110110_00000000_00000000_00000000, 7), + (0b11111111_11111100_00000000_00000000, 15), + (0b11111111_10000000_00000000_00000000, 11), + (0b11111111_11110100_00000000_00000000, 14), + (0b11111111_11101000_00000000_00000000, 13), + (0b11111111_11111111_11111111_11000000, 28), + (0b11111111_11111110_01100000_00000000, 20), + (0b11111111_11111111_01001000_00000000, 22), + (0b11111111_11111110_01110000_00000000, 20), + (0b11111111_11111110_10000000_00000000, 20), + (0b11111111_11111111_01001100_00000000, 22), + (0b11111111_11111111_01010000_00000000, 22), + (0b11111111_11111111_01010100_00000000, 22), + (0b11111111_11111111_10110010_00000000, 23), + (0b11111111_11111111_01011000_00000000, 22), + (0b11111111_11111111_10110100_00000000, 23), + (0b11111111_11111111_10110110_00000000, 23), + (0b11111111_11111111_10111000_00000000, 23), + (0b11111111_11111111_10111010_00000000, 23), + (0b11111111_11111111_10111100_00000000, 23), + (0b11111111_11111111_11101011_00000000, 24), + (0b11111111_11111111_10111110_00000000, 23), + (0b11111111_11111111_11101100_00000000, 24), + (0b11111111_11111111_11101101_00000000, 24), + (0b11111111_11111111_01011100_00000000, 22), + (0b11111111_11111111_11000000_00000000, 23), + (0b11111111_11111111_11101110_00000000, 24), + (0b11111111_11111111_11000010_00000000, 23), + (0b11111111_11111111_11000100_00000000, 23), + (0b11111111_11111111_11000110_00000000, 23), + (0b11111111_11111111_11001000_00000000, 23), + (0b11111111_11111110_11100000_00000000, 21), + (0b11111111_11111111_01100000_00000000, 22), + (0b11111111_11111111_11001010_00000000, 23), + (0b11111111_11111111_01100100_00000000, 22), + (0b11111111_11111111_11001100_00000000, 23), + (0b11111111_11111111_11001110_00000000, 23), + (0b11111111_11111111_11101111_00000000, 24), + (0b11111111_11111111_01101000_00000000, 22), + (0b11111111_11111110_11101000_00000000, 21), + (0b11111111_11111110_10010000_00000000, 20), + (0b11111111_11111111_01101100_00000000, 22), + (0b11111111_11111111_01110000_00000000, 22), + (0b11111111_11111111_11010000_00000000, 23), + (0b11111111_11111111_11010010_00000000, 23), + (0b11111111_11111110_11110000_00000000, 21), + (0b11111111_11111111_11010100_00000000, 23), + (0b11111111_11111111_01110100_00000000, 22), + (0b11111111_11111111_01111000_00000000, 22), + (0b11111111_11111111_11110000_00000000, 24), + (0b11111111_11111110_11111000_00000000, 21), + (0b11111111_11111111_01111100_00000000, 22), + (0b11111111_11111111_11010110_00000000, 23), + (0b11111111_11111111_11011000_00000000, 23), + (0b11111111_11111111_00000000_00000000, 21), + (0b11111111_11111111_00001000_00000000, 21), + (0b11111111_11111111_10000000_00000000, 22), + (0b11111111_11111111_00010000_00000000, 21), + (0b11111111_11111111_11011010_00000000, 23), + (0b11111111_11111111_10000100_00000000, 22), + (0b11111111_11111111_11011100_00000000, 23), + (0b11111111_11111111_11011110_00000000, 23), + (0b11111111_11111110_10100000_00000000, 20), + (0b11111111_11111111_10001000_00000000, 22), + (0b11111111_11111111_10001100_00000000, 22), + (0b11111111_11111111_10010000_00000000, 22), + (0b11111111_11111111_11100000_00000000, 23), + (0b11111111_11111111_10010100_00000000, 22), + (0b11111111_11111111_10011000_00000000, 22), + (0b11111111_11111111_11100010_00000000, 23), + (0b11111111_11111111_11111000_00000000, 26), + (0b11111111_11111111_11111000_01000000, 26), + (0b11111111_11111110_10110000_00000000, 20), + (0b11111111_11111110_00100000_00000000, 19), + (0b11111111_11111111_10011100_00000000, 22), + (0b11111111_11111111_11100100_00000000, 23), + (0b11111111_11111111_10100000_00000000, 22), + (0b11111111_11111111_11110110_00000000, 25), + (0b11111111_11111111_11111000_10000000, 26), + (0b11111111_11111111_11111000_11000000, 26), + (0b11111111_11111111_11111001_00000000, 26), + (0b11111111_11111111_11111011_11000000, 27), + (0b11111111_11111111_11111011_11100000, 27), + (0b11111111_11111111_11111001_01000000, 26), + (0b11111111_11111111_11110001_00000000, 24), + (0b11111111_11111111_11110110_10000000, 25), + (0b11111111_11111110_01000000_00000000, 19), + (0b11111111_11111111_00011000_00000000, 21), + (0b11111111_11111111_11111001_10000000, 26), + (0b11111111_11111111_11111100_00000000, 27), + (0b11111111_11111111_11111100_00100000, 27), + (0b11111111_11111111_11111001_11000000, 26), + (0b11111111_11111111_11111100_01000000, 27), + (0b11111111_11111111_11110010_00000000, 24), + (0b11111111_11111111_00100000_00000000, 21), + (0b11111111_11111111_00101000_00000000, 21), + (0b11111111_11111111_11111010_00000000, 26), + (0b11111111_11111111_11111010_01000000, 26), + (0b11111111_11111111_11111111_11010000, 28), + (0b11111111_11111111_11111100_01100000, 27), + (0b11111111_11111111_11111100_10000000, 27), + (0b11111111_11111111_11111100_10100000, 27), + (0b11111111_11111110_11000000_00000000, 20), + (0b11111111_11111111_11110011_00000000, 24), + (0b11111111_11111110_11010000_00000000, 20), + (0b11111111_11111111_00110000_00000000, 21), + (0b11111111_11111111_10100100_00000000, 22), + (0b11111111_11111111_00111000_00000000, 21), + (0b11111111_11111111_01000000_00000000, 21), + (0b11111111_11111111_11100110_00000000, 23), + (0b11111111_11111111_10101000_00000000, 22), + (0b11111111_11111111_10101100_00000000, 22), + (0b11111111_11111111_11110111_00000000, 25), + (0b11111111_11111111_11110111_10000000, 25), + (0b11111111_11111111_11110100_00000000, 24), + (0b11111111_11111111_11110101_00000000, 24), + (0b11111111_11111111_11111010_10000000, 26), + (0b11111111_11111111_11101000_00000000, 23), + (0b11111111_11111111_11111010_11000000, 26), + (0b11111111_11111111_11111100_11000000, 27), + (0b11111111_11111111_11111011_00000000, 26), + (0b11111111_11111111_11111011_01000000, 26), + (0b11111111_11111111_11111100_11100000, 27), + (0b11111111_11111111_11111101_00000000, 27), + (0b11111111_11111111_11111101_00100000, 27), + (0b11111111_11111111_11111101_01000000, 27), + (0b11111111_11111111_11111101_01100000, 27), + (0b11111111_11111111_11111111_11100000, 28), + (0b11111111_11111111_11111101_10000000, 27), + (0b11111111_11111111_11111101_10100000, 27), + (0b11111111_11111111_11111101_11000000, 27), + (0b11111111_11111111_11111101_11100000, 27), + (0b11111111_11111111_11111110_00000000, 27), + (0b11111111_11111111_11111011_10000000, 26), + (0b11111111_11111111_11111111_11111100, 30) + }; + + private static readonly (int codeLength, int[] codes)[] _decodingTable = new[] + { + (5, new[] { 48, 49, 50, 97, 99, 101, 105, 111, 115, 116 }), + (6, new[] { 32, 37, 45, 46, 47, 51, 52, 53, 54, 55, 56, 57, 61, 65, 95, 98, 100, 102, 103, 104, 108, 109, 110, 112, 114, 117 }), + (7, new[] { 58, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 106, 107, 113, 118, 119, 120, 121, 122 }), + (8, new[] { 38, 42, 44, 59, 88, 90 }), + (10, new[] { 33, 34, 40, 41, 63 }), + (11, new[] { 39, 43, 124 }), + (12, new[] { 35, 62 }), + (13, new[] { 0, 36, 64, 91, 93, 126 }), + (14, new[] { 94, 125 }), + (15, new[] { 60, 96, 123 }), + (19, new[] { 92, 195, 208 }), + (20, new[] { 128, 130, 131, 162, 184, 194, 224, 226 }), + (21, new[] { 153, 161, 167, 172, 176, 177, 179, 209, 216, 217, 227, 229, 230 }), + (22, new[] { 129, 132, 133, 134, 136, 146, 154, 156, 160, 163, 164, 169, 170, 173, 178, 181, 185, 186, 187, 189, 190, 196, 198, 228, 232, 233 }), + (23, new[] { 1, 135, 137, 138, 139, 140, 141, 143, 147, 149, 150, 151, 152, 155, 157, 158, 165, 166, 168, 174, 175, 180, 182, 183, 188, 191, 197, 231, 239 }), + (24, new[] { 9, 142, 144, 145, 148, 159, 171, 206, 215, 225, 236, 237 }), + (25, new[] { 199, 207, 234, 235 }), + (26, new[] { 192, 193, 200, 201, 202, 205, 210, 213, 218, 219, 238, 240, 242, 243, 255 }), + (27, new[] { 203, 204, 211, 212, 214, 221, 222, 223, 241, 244, 245, 246, 247, 248, 250, 251, 252, 253, 254 }), + (28, new[] { 2, 3, 4, 5, 6, 7, 8, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 27, 28, 29, 30, 31, 127, 220, 249 }), + (30, new[] { 10, 13, 22, 256 }) + }; + + public static (uint encoded, int bitLength) Encode(int data) + { + return _encodingTable[data]; + } + + /// + /// Decodes a Huffman encoded string from a byte array. + /// + /// The source byte array containing the encoded data. + /// The offset in the byte array where the coded data starts. + /// The number of bytes to decode. + /// The destination byte array to store the decoded data. + /// The number of decoded symbols. + public static int Decode(byte[] src, int offset, int count, byte[] dst) + { + var i = offset; + var j = 0; + var lastDecodedBits = 0; + while (i < count) + { + var next = (uint)(src[i] << 24 + lastDecodedBits); + next |= (i + 1 < src.Length ? (uint)(src[i + 1] << 16 + lastDecodedBits) : 0); + next |= (i + 2 < src.Length ? (uint)(src[i + 2] << 8 + lastDecodedBits) : 0); + next |= (i + 3 < src.Length ? (uint)(src[i + 3] << lastDecodedBits) : 0); + + var ones = (uint)(int.MinValue >> (8 - lastDecodedBits - 1)); + if (i == count - 1 && lastDecodedBits > 0 && (next & ones) == ones) + { + // The remaining 7 or less bits are all 1, which is padding. + // We specifically check that lastDecodedBits > 0 because padding + // longer than 7 bits should be treated as a decoding error. + // http://httpwg.org/specs/rfc7541.html#rfc.section.5.2 + break; + } + + // The longest possible symbol size is 30 bits. If we're at the last 4 bytes + // of the input, we need to make sure we pass the correct number of valid bits + // left, otherwise the trailing 0s in next may form a valid symbol. + var validBits = Math.Min(30, (8 - lastDecodedBits) + (count - i - 1) * 8); + var ch = Decode(next, validBits, out var decodedBits); + + if (ch == -1) + { + // No valid symbol could be decoded with the bits in next + throw new HuffmanDecodingException(CoreStrings.HPackHuffmanErrorIncomplete); + } + else if (ch == 256) + { + // A Huffman-encoded string literal containing the EOS symbol MUST be treated as a decoding error. + // http://httpwg.org/specs/rfc7541.html#rfc.section.5.2 + throw new HuffmanDecodingException(CoreStrings.HPackHuffmanErrorEOS); + } + + if (j == dst.Length) + { + throw new HuffmanDecodingException(CoreStrings.HPackHuffmanErrorDestinationTooSmall); + } + + dst[j++] = (byte)ch; + + // If we crossed a byte boundary, advance i so we start at the next byte that's not fully decoded. + lastDecodedBits += decodedBits; + i += lastDecodedBits / 8; + + // Modulo 8 since we only care about how many bits were decoded in the last byte that we processed. + lastDecodedBits %= 8; + } + + return j; + } + + /// + /// Decodes a single symbol from a 32-bit word. + /// + /// A 32-bit word containing a Huffman encoded symbol. + /// + /// The number of bits in that may contain an encoded symbol. + /// This is not the exact number of bits that encode the symbol. Instead, it prevents + /// decoding the lower bits of if they don't contain any + /// encoded data. + /// + /// The number of bits decoded from . + /// The decoded symbol. + public static int Decode(uint data, int validBits, out int decodedBits) + { + // The code below implements the decoding logic for a canonical Huffman code. + // + // To decode a symbol, we scan the decoding table, which is sorted by ascending symbol bit length. + // For each bit length b, we determine the maximum b-bit encoded value, plus one (that is codeMax). + // This is done with the following logic: + // + // if we're at the first entry in the table, + // codeMax = the # of symbols encoded in b bits + // else, + // left-shift codeMax by the difference between b and the previous entry's bit length, + // then increment codeMax by the # of symbols encoded in b bits + // + // Next, we look at the value v encoded in the highest b bits of data. If v is less than codeMax, + // those bits correspond to a Huffman encoded symbol. We find the corresponding decoded + // symbol in the list of values associated with bit length b in the decoding table by indexing it + // with codeMax - v. + + var codeMax = 0; + + for (var i = 0; i < _decodingTable.Length && _decodingTable[i].codeLength <= validBits; i++) + { + var (codeLength, codes) = _decodingTable[i]; + + if (i > 0) + { + codeMax <<= codeLength - _decodingTable[i - 1].codeLength; + } + + codeMax += codes.Length; + + var mask = int.MinValue >> (codeLength - 1); + var masked = (data & mask) >> (32 - codeLength); + + if (masked < codeMax) + { + decodedBits = codeLength; + return codes[codes.Length - (codeMax - masked)]; + } + } + + decodedBits = 0; + return -1; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HuffmanDecodingException.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HuffmanDecodingException.cs new file mode 100644 index 0000000000..3bd992ab4b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/HuffmanDecodingException.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class HuffmanDecodingException : Exception + { + public HuffmanDecodingException(string message) + : base(message) + { + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/IntegerDecoder.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/IntegerDecoder.cs new file mode 100644 index 0000000000..5bc051a9a3 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/IntegerDecoder.cs @@ -0,0 +1,42 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class IntegerDecoder + { + private int _i; + private int _m; + + public int Value { get; private set; } + + public bool BeginDecode(byte b, int prefixLength) + { + if (b < ((1 << prefixLength) - 1)) + { + Value = b; + return true; + } + else + { + _i = b; + _m = 0; + return false; + } + } + + public bool Decode(byte b) + { + _i = _i + (b & 127) * (1 << _m); + _m = _m + 7; + + if ((b & 128) != 128) + { + Value = _i; + return true; + } + + return false; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/IntegerEncoder.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/IntegerEncoder.cs new file mode 100644 index 0000000000..6385459d14 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/IntegerEncoder.cs @@ -0,0 +1,59 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public static class IntegerEncoder + { + public static bool Encode(int i, int n, Span buffer, out int length) + { + var j = 0; + length = 0; + + if (buffer.Length == 0) + { + return false; + } + + if (i < (1 << n) - 1) + { + buffer[j] &= MaskHigh(8 - n); + buffer[j++] |= (byte)i; + } + else + { + buffer[j] &= MaskHigh(8 - n); + buffer[j++] |= (byte)((1 << n) - 1); + + if (j == buffer.Length) + { + return false; + } + + i = i - ((1 << n) - 1); + while (i >= 128) + { + buffer[j++] = (byte)(i % 128 + 128); + + if (j > buffer.Length) + { + return false; + } + + i = i / 128; + } + buffer[j++] = (byte)i; + } + + length = j; + return true; + } + + private static byte MaskHigh(int n) + { + return (byte)(sbyte.MinValue >> (n - 1)); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/StaticTable.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/StaticTable.cs new file mode 100644 index 0000000000..c28a78ff8d --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/StaticTable.cs @@ -0,0 +1,104 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class StaticTable + { + private static readonly StaticTable _instance = new StaticTable(); + + private readonly Dictionary _statusIndex = new Dictionary + { + [200] = 8, + [204] = 9, + [206] = 10, + [304] = 11, + [400] = 12, + [404] = 13, + [500] = 14, + }; + + private StaticTable() + { + } + + public static StaticTable Instance => _instance; + + public int Count => _staticTable.Length; + + public HeaderField this[int index] => _staticTable[index]; + + public IReadOnlyDictionary StatusIndex => _statusIndex; + + private readonly HeaderField[] _staticTable = new HeaderField[] + { + CreateHeaderField(":authority", ""), + CreateHeaderField(":method", "GET"), + CreateHeaderField(":method", "POST"), + CreateHeaderField(":path", "/"), + CreateHeaderField(":path", "/index.html"), + CreateHeaderField(":scheme", "http"), + CreateHeaderField(":scheme", "https"), + CreateHeaderField(":status", "200"), + CreateHeaderField(":status", "204"), + CreateHeaderField(":status", "206"), + CreateHeaderField(":status", "304"), + CreateHeaderField(":status", "400"), + CreateHeaderField(":status", "404"), + CreateHeaderField(":status", "500"), + CreateHeaderField("accept-charset", ""), + CreateHeaderField("accept-encoding", "gzip, deflate"), + CreateHeaderField("accept-language", ""), + CreateHeaderField("accept-ranges", ""), + CreateHeaderField("accept", ""), + CreateHeaderField("access-control-allow-origin", ""), + CreateHeaderField("age", ""), + CreateHeaderField("allow", ""), + CreateHeaderField("authorization", ""), + CreateHeaderField("cache-control", ""), + CreateHeaderField("content-disposition", ""), + CreateHeaderField("content-encoding", ""), + CreateHeaderField("content-language", ""), + CreateHeaderField("content-length", ""), + CreateHeaderField("content-location", ""), + CreateHeaderField("content-range", ""), + CreateHeaderField("content-type", ""), + CreateHeaderField("cookie", ""), + CreateHeaderField("date", ""), + CreateHeaderField("etag", ""), + CreateHeaderField("expect", ""), + CreateHeaderField("expires", ""), + CreateHeaderField("from", ""), + CreateHeaderField("host", ""), + CreateHeaderField("if-match", ""), + CreateHeaderField("if-modified-since", ""), + CreateHeaderField("if-none-match", ""), + CreateHeaderField("if-range", ""), + CreateHeaderField("if-unmodifiedsince", ""), + CreateHeaderField("last-modified", ""), + CreateHeaderField("link", ""), + CreateHeaderField("location", ""), + CreateHeaderField("max-forwards", ""), + CreateHeaderField("proxy-authenticate", ""), + CreateHeaderField("proxy-authorization", ""), + CreateHeaderField("range", ""), + CreateHeaderField("referer", ""), + CreateHeaderField("refresh", ""), + CreateHeaderField("retry-after", ""), + CreateHeaderField("server", ""), + CreateHeaderField("set-cookie", ""), + CreateHeaderField("strict-transport-security", ""), + CreateHeaderField("transfer-encoding", ""), + CreateHeaderField("user-agent", ""), + CreateHeaderField("vary", ""), + CreateHeaderField("via", ""), + CreateHeaderField("www-authenticate", "") + }; + + private static HeaderField CreateHeaderField(string name, string value) + => new HeaderField(Encoding.ASCII.GetBytes(name), Encoding.ASCII.GetBytes(value)); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/StatusCodes.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/StatusCodes.cs new file mode 100644 index 0000000000..056d5a8a1a --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/HPack/StatusCodes.cs @@ -0,0 +1,222 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Globalization; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public static class StatusCodes + { + private static readonly byte[] _bytesStatus100 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status100Continue); + private static readonly byte[] _bytesStatus101 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status101SwitchingProtocols); + private static readonly byte[] _bytesStatus102 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status102Processing); + + private static readonly byte[] _bytesStatus200 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status200OK); + private static readonly byte[] _bytesStatus201 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status201Created); + private static readonly byte[] _bytesStatus202 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status202Accepted); + private static readonly byte[] _bytesStatus203 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status203NonAuthoritative); + private static readonly byte[] _bytesStatus204 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status204NoContent); + private static readonly byte[] _bytesStatus205 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status205ResetContent); + private static readonly byte[] _bytesStatus206 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status206PartialContent); + private static readonly byte[] _bytesStatus207 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status207MultiStatus); + private static readonly byte[] _bytesStatus208 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status208AlreadyReported); + private static readonly byte[] _bytesStatus226 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status226IMUsed); + + private static readonly byte[] _bytesStatus300 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status300MultipleChoices); + private static readonly byte[] _bytesStatus301 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status301MovedPermanently); + private static readonly byte[] _bytesStatus302 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status302Found); + private static readonly byte[] _bytesStatus303 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status303SeeOther); + private static readonly byte[] _bytesStatus304 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status304NotModified); + private static readonly byte[] _bytesStatus305 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status305UseProxy); + private static readonly byte[] _bytesStatus306 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status306SwitchProxy); + private static readonly byte[] _bytesStatus307 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status307TemporaryRedirect); + private static readonly byte[] _bytesStatus308 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status308PermanentRedirect); + + private static readonly byte[] _bytesStatus400 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status400BadRequest); + private static readonly byte[] _bytesStatus401 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status401Unauthorized); + private static readonly byte[] _bytesStatus402 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status402PaymentRequired); + private static readonly byte[] _bytesStatus403 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status403Forbidden); + private static readonly byte[] _bytesStatus404 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status404NotFound); + private static readonly byte[] _bytesStatus405 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status405MethodNotAllowed); + private static readonly byte[] _bytesStatus406 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status406NotAcceptable); + private static readonly byte[] _bytesStatus407 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status407ProxyAuthenticationRequired); + private static readonly byte[] _bytesStatus408 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status408RequestTimeout); + private static readonly byte[] _bytesStatus409 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status409Conflict); + private static readonly byte[] _bytesStatus410 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status410Gone); + private static readonly byte[] _bytesStatus411 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status411LengthRequired); + private static readonly byte[] _bytesStatus412 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status412PreconditionFailed); + private static readonly byte[] _bytesStatus413 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status413PayloadTooLarge); + private static readonly byte[] _bytesStatus414 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status414UriTooLong); + private static readonly byte[] _bytesStatus415 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status415UnsupportedMediaType); + private static readonly byte[] _bytesStatus416 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status416RangeNotSatisfiable); + private static readonly byte[] _bytesStatus417 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status417ExpectationFailed); + private static readonly byte[] _bytesStatus418 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status418ImATeapot); + private static readonly byte[] _bytesStatus419 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status419AuthenticationTimeout); + private static readonly byte[] _bytesStatus421 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status421MisdirectedRequest); + private static readonly byte[] _bytesStatus422 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status422UnprocessableEntity); + private static readonly byte[] _bytesStatus423 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status423Locked); + private static readonly byte[] _bytesStatus424 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status424FailedDependency); + private static readonly byte[] _bytesStatus426 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status426UpgradeRequired); + private static readonly byte[] _bytesStatus428 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status428PreconditionRequired); + private static readonly byte[] _bytesStatus429 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status429TooManyRequests); + private static readonly byte[] _bytesStatus431 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status431RequestHeaderFieldsTooLarge); + private static readonly byte[] _bytesStatus451 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status451UnavailableForLegalReasons); + + private static readonly byte[] _bytesStatus500 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status500InternalServerError); + private static readonly byte[] _bytesStatus501 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status501NotImplemented); + private static readonly byte[] _bytesStatus502 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status502BadGateway); + private static readonly byte[] _bytesStatus503 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status503ServiceUnavailable); + private static readonly byte[] _bytesStatus504 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status504GatewayTimeout); + private static readonly byte[] _bytesStatus505 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status505HttpVersionNotsupported); + private static readonly byte[] _bytesStatus506 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status506VariantAlsoNegotiates); + private static readonly byte[] _bytesStatus507 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status507InsufficientStorage); + private static readonly byte[] _bytesStatus508 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status508LoopDetected); + private static readonly byte[] _bytesStatus510 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status510NotExtended); + private static readonly byte[] _bytesStatus511 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status511NetworkAuthenticationRequired); + + private static byte[] CreateStatusBytes(int statusCode) + { + return Encoding.ASCII.GetBytes(statusCode.ToString(CultureInfo.InvariantCulture)); + } + + public static byte[] ToStatusBytes(int statusCode) + { + switch (statusCode) + { + case Microsoft.AspNetCore.Http.StatusCodes.Status100Continue: + return _bytesStatus100; + case Microsoft.AspNetCore.Http.StatusCodes.Status101SwitchingProtocols: + return _bytesStatus101; + case Microsoft.AspNetCore.Http.StatusCodes.Status102Processing: + return _bytesStatus102; + + case Microsoft.AspNetCore.Http.StatusCodes.Status200OK: + return _bytesStatus200; + case Microsoft.AspNetCore.Http.StatusCodes.Status201Created: + return _bytesStatus201; + case Microsoft.AspNetCore.Http.StatusCodes.Status202Accepted: + return _bytesStatus202; + case Microsoft.AspNetCore.Http.StatusCodes.Status203NonAuthoritative: + return _bytesStatus203; + case Microsoft.AspNetCore.Http.StatusCodes.Status204NoContent: + return _bytesStatus204; + case Microsoft.AspNetCore.Http.StatusCodes.Status205ResetContent: + return _bytesStatus205; + case Microsoft.AspNetCore.Http.StatusCodes.Status206PartialContent: + return _bytesStatus206; + case Microsoft.AspNetCore.Http.StatusCodes.Status207MultiStatus: + return _bytesStatus207; + case Microsoft.AspNetCore.Http.StatusCodes.Status208AlreadyReported: + return _bytesStatus208; + case Microsoft.AspNetCore.Http.StatusCodes.Status226IMUsed: + return _bytesStatus226; + + case Microsoft.AspNetCore.Http.StatusCodes.Status300MultipleChoices: + return _bytesStatus300; + case Microsoft.AspNetCore.Http.StatusCodes.Status301MovedPermanently: + return _bytesStatus301; + case Microsoft.AspNetCore.Http.StatusCodes.Status302Found: + return _bytesStatus302; + case Microsoft.AspNetCore.Http.StatusCodes.Status303SeeOther: + return _bytesStatus303; + case Microsoft.AspNetCore.Http.StatusCodes.Status304NotModified: + return _bytesStatus304; + case Microsoft.AspNetCore.Http.StatusCodes.Status305UseProxy: + return _bytesStatus305; + case Microsoft.AspNetCore.Http.StatusCodes.Status306SwitchProxy: + return _bytesStatus306; + case Microsoft.AspNetCore.Http.StatusCodes.Status307TemporaryRedirect: + return _bytesStatus307; + case Microsoft.AspNetCore.Http.StatusCodes.Status308PermanentRedirect: + return _bytesStatus308; + + case Microsoft.AspNetCore.Http.StatusCodes.Status400BadRequest: + return _bytesStatus400; + case Microsoft.AspNetCore.Http.StatusCodes.Status401Unauthorized: + return _bytesStatus401; + case Microsoft.AspNetCore.Http.StatusCodes.Status402PaymentRequired: + return _bytesStatus402; + case Microsoft.AspNetCore.Http.StatusCodes.Status403Forbidden: + return _bytesStatus403; + case Microsoft.AspNetCore.Http.StatusCodes.Status404NotFound: + return _bytesStatus404; + case Microsoft.AspNetCore.Http.StatusCodes.Status405MethodNotAllowed: + return _bytesStatus405; + case Microsoft.AspNetCore.Http.StatusCodes.Status406NotAcceptable: + return _bytesStatus406; + case Microsoft.AspNetCore.Http.StatusCodes.Status407ProxyAuthenticationRequired: + return _bytesStatus407; + case Microsoft.AspNetCore.Http.StatusCodes.Status408RequestTimeout: + return _bytesStatus408; + case Microsoft.AspNetCore.Http.StatusCodes.Status409Conflict: + return _bytesStatus409; + case Microsoft.AspNetCore.Http.StatusCodes.Status410Gone: + return _bytesStatus410; + case Microsoft.AspNetCore.Http.StatusCodes.Status411LengthRequired: + return _bytesStatus411; + case Microsoft.AspNetCore.Http.StatusCodes.Status412PreconditionFailed: + return _bytesStatus412; + case Microsoft.AspNetCore.Http.StatusCodes.Status413PayloadTooLarge: + return _bytesStatus413; + case Microsoft.AspNetCore.Http.StatusCodes.Status414UriTooLong: + return _bytesStatus414; + case Microsoft.AspNetCore.Http.StatusCodes.Status415UnsupportedMediaType: + return _bytesStatus415; + case Microsoft.AspNetCore.Http.StatusCodes.Status416RangeNotSatisfiable: + return _bytesStatus416; + case Microsoft.AspNetCore.Http.StatusCodes.Status417ExpectationFailed: + return _bytesStatus417; + case Microsoft.AspNetCore.Http.StatusCodes.Status418ImATeapot: + return _bytesStatus418; + case Microsoft.AspNetCore.Http.StatusCodes.Status419AuthenticationTimeout: + return _bytesStatus419; + case Microsoft.AspNetCore.Http.StatusCodes.Status421MisdirectedRequest: + return _bytesStatus421; + case Microsoft.AspNetCore.Http.StatusCodes.Status422UnprocessableEntity: + return _bytesStatus422; + case Microsoft.AspNetCore.Http.StatusCodes.Status423Locked: + return _bytesStatus423; + case Microsoft.AspNetCore.Http.StatusCodes.Status424FailedDependency: + return _bytesStatus424; + case Microsoft.AspNetCore.Http.StatusCodes.Status426UpgradeRequired: + return _bytesStatus426; + case Microsoft.AspNetCore.Http.StatusCodes.Status428PreconditionRequired: + return _bytesStatus428; + case Microsoft.AspNetCore.Http.StatusCodes.Status429TooManyRequests: + return _bytesStatus429; + case Microsoft.AspNetCore.Http.StatusCodes.Status431RequestHeaderFieldsTooLarge: + return _bytesStatus431; + case Microsoft.AspNetCore.Http.StatusCodes.Status451UnavailableForLegalReasons: + return _bytesStatus451; + + case Microsoft.AspNetCore.Http.StatusCodes.Status500InternalServerError: + return _bytesStatus500; + case Microsoft.AspNetCore.Http.StatusCodes.Status501NotImplemented: + return _bytesStatus501; + case Microsoft.AspNetCore.Http.StatusCodes.Status502BadGateway: + return _bytesStatus502; + case Microsoft.AspNetCore.Http.StatusCodes.Status503ServiceUnavailable: + return _bytesStatus503; + case Microsoft.AspNetCore.Http.StatusCodes.Status504GatewayTimeout: + return _bytesStatus504; + case Microsoft.AspNetCore.Http.StatusCodes.Status505HttpVersionNotsupported: + return _bytesStatus505; + case Microsoft.AspNetCore.Http.StatusCodes.Status506VariantAlsoNegotiates: + return _bytesStatus506; + case Microsoft.AspNetCore.Http.StatusCodes.Status507InsufficientStorage: + return _bytesStatus507; + case Microsoft.AspNetCore.Http.StatusCodes.Status508LoopDetected: + return _bytesStatus508; + case Microsoft.AspNetCore.Http.StatusCodes.Status510NotExtended: + return _bytesStatus510; + case Microsoft.AspNetCore.Http.StatusCodes.Status511NetworkAuthenticationRequired: + return _bytesStatus511; + + default: + return Encoding.ASCII.GetBytes(statusCode.ToString(CultureInfo.InvariantCulture)); + + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs new file mode 100644 index 0000000000..e3845ddfe2 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs @@ -0,0 +1,877 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections; +using System.Collections.Concurrent; +using System.IO.Pipelines; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2Connection : ITimeoutControl, IHttp2StreamLifetimeHandler, IHttpHeadersHandler, IRequestProcessor + { + private enum RequestHeaderParsingState + { + Ready, + PseudoHeaderFields, + Headers, + Trailers + } + + [Flags] + private enum PseudoHeaderFields + { + None = 0x0, + Authority = 0x1, + Method = 0x2, + Path = 0x4, + Scheme = 0x8, + Status = 0x10, + Unknown = 0x40000000 + } + + public static byte[] ClientPreface { get; } = Encoding.ASCII.GetBytes("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"); + + private static readonly PseudoHeaderFields _mandatoryRequestPseudoHeaderFields = + PseudoHeaderFields.Method | PseudoHeaderFields.Path | PseudoHeaderFields.Scheme; + + private static readonly byte[] _authorityBytes = Encoding.ASCII.GetBytes("authority"); + private static readonly byte[] _methodBytes = Encoding.ASCII.GetBytes("method"); + private static readonly byte[] _pathBytes = Encoding.ASCII.GetBytes("path"); + private static readonly byte[] _schemeBytes = Encoding.ASCII.GetBytes("scheme"); + private static readonly byte[] _statusBytes = Encoding.ASCII.GetBytes("status"); + private static readonly byte[] _connectionBytes = Encoding.ASCII.GetBytes("connection"); + private static readonly byte[] _teBytes = Encoding.ASCII.GetBytes("te"); + private static readonly byte[] _trailersBytes = Encoding.ASCII.GetBytes("trailers"); + private static readonly byte[] _connectBytes = Encoding.ASCII.GetBytes("CONNECT"); + + private readonly Http2ConnectionContext _context; + private readonly Http2FrameWriter _frameWriter; + private readonly HPackDecoder _hpackDecoder; + + private readonly Http2PeerSettings _serverSettings = new Http2PeerSettings(); + private readonly Http2PeerSettings _clientSettings = new Http2PeerSettings(); + + private readonly Http2Frame _incomingFrame = new Http2Frame(); + + private Http2Stream _currentHeadersStream; + private RequestHeaderParsingState _requestHeaderParsingState; + private PseudoHeaderFields _parsedPseudoHeaderFields; + private bool _isMethodConnect; + private int _highestOpenedStreamId; + + private bool _stopping; + + private readonly ConcurrentDictionary _streams = new ConcurrentDictionary(); + + public Http2Connection(Http2ConnectionContext context) + { + _context = context; + _frameWriter = new Http2FrameWriter(context.Transport.Output, context.Application.Input); + _hpackDecoder = new HPackDecoder((int)_serverSettings.HeaderTableSize); + } + + public string ConnectionId => _context.ConnectionId; + + public PipeReader Input => _context.Transport.Input; + + public IKestrelTrace Log => _context.ServiceContext.Log; + + public IFeatureCollection ConnectionFeatures => _context.ConnectionFeatures; + + public void OnInputOrOutputCompleted() + { + _stopping = true; + _frameWriter.Abort(ex: null); + } + + public void Abort(ConnectionAbortedException ex) + { + _stopping = true; + _frameWriter.Abort(ex); + } + + public void StopProcessingNextRequest() + { + _stopping = true; + Input.CancelPendingRead(); + } + + public async Task ProcessRequestsAsync(IHttpApplication application) + { + Exception error = null; + var errorCode = Http2ErrorCode.NO_ERROR; + + try + { + while (!_stopping) + { + var result = await Input.ReadAsync(); + var readableBuffer = result.Buffer; + var consumed = readableBuffer.Start; + var examined = readableBuffer.End; + + try + { + if (!readableBuffer.IsEmpty && ParsePreface(readableBuffer, out consumed, out examined)) + { + break; + } + + if (result.IsCompleted) + { + return; + } + } + finally + { + Input.AdvanceTo(consumed, examined); + } + } + + if (!_stopping) + { + await _frameWriter.WriteSettingsAsync(_serverSettings); + } + + while (!_stopping) + { + var result = await Input.ReadAsync(); + var readableBuffer = result.Buffer; + var consumed = readableBuffer.Start; + var examined = readableBuffer.End; + + try + { + if (!readableBuffer.IsEmpty && Http2FrameReader.ReadFrame(readableBuffer, _incomingFrame, out consumed, out examined)) + { + Log.LogTrace($"Connection id {ConnectionId} received {_incomingFrame.Type} frame with flags 0x{_incomingFrame.Flags:x} and length {_incomingFrame.Length} for stream ID {_incomingFrame.StreamId}"); + await ProcessFrameAsync(application); + } + else if (result.IsCompleted) + { + return; + } + } + finally + { + Input.AdvanceTo(consumed, examined); + } + } + } + catch (ConnectionResetException ex) + { + // Don't log ECONNRESET errors when there are no active streams on the connection. Browsers like IE will reset connections regularly. + if (_streams.Count > 0) + { + Log.RequestProcessingError(ConnectionId, ex); + } + + error = ex; + } + catch (Http2ConnectionErrorException ex) + { + Log.Http2ConnectionError(ConnectionId, ex); + error = ex; + errorCode = ex.ErrorCode; + } + catch (HPackDecodingException ex) + { + Log.HPackDecodingError(ConnectionId, _currentHeadersStream.StreamId, ex); + error = ex; + errorCode = Http2ErrorCode.COMPRESSION_ERROR; + } + catch (Exception ex) + { + error = ex; + errorCode = Http2ErrorCode.INTERNAL_ERROR; + throw; + } + finally + { + try + { + foreach (var stream in _streams.Values) + { + stream.Http2Abort(error); + } + + await _frameWriter.WriteGoAwayAsync(_highestOpenedStreamId, errorCode); + } + finally + { + Input.Complete(); + _frameWriter.Abort(ex: null); + } + } + } + + private bool ParsePreface(ReadOnlySequence readableBuffer, out SequencePosition consumed, out SequencePosition examined) + { + consumed = readableBuffer.Start; + examined = readableBuffer.End; + + if (readableBuffer.Length < ClientPreface.Length) + { + return false; + } + + var span = readableBuffer.IsSingleSegment + ? readableBuffer.First.Span + : readableBuffer.ToSpan(); + + for (var i = 0; i < ClientPreface.Length; i++) + { + if (ClientPreface[i] != span[i]) + { + throw new Exception("Invalid HTTP/2 connection preface."); + } + } + + consumed = examined = readableBuffer.GetPosition(ClientPreface.Length); + return true; + } + + private Task ProcessFrameAsync(IHttpApplication application) + { + // http://httpwg.org/specs/rfc7540.html#rfc.section.5.1.1 + // Streams initiated by a client MUST use odd-numbered stream identifiers; ... + // An endpoint that receives an unexpected stream identifier MUST respond with + // a connection error (Section 5.4.1) of type PROTOCOL_ERROR. + if (_incomingFrame.StreamId != 0 && (_incomingFrame.StreamId & 1) == 0) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamIdEven(_incomingFrame.Type, _incomingFrame.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + switch (_incomingFrame.Type) + { + case Http2FrameType.DATA: + return ProcessDataFrameAsync(); + case Http2FrameType.HEADERS: + return ProcessHeadersFrameAsync(application); + case Http2FrameType.PRIORITY: + return ProcessPriorityFrameAsync(); + case Http2FrameType.RST_STREAM: + return ProcessRstStreamFrameAsync(); + case Http2FrameType.SETTINGS: + return ProcessSettingsFrameAsync(); + case Http2FrameType.PUSH_PROMISE: + throw new Http2ConnectionErrorException(CoreStrings.Http2ErrorPushPromiseReceived, Http2ErrorCode.PROTOCOL_ERROR); + case Http2FrameType.PING: + return ProcessPingFrameAsync(); + case Http2FrameType.GOAWAY: + return ProcessGoAwayFrameAsync(); + case Http2FrameType.WINDOW_UPDATE: + return ProcessWindowUpdateFrameAsync(); + case Http2FrameType.CONTINUATION: + return ProcessContinuationFrameAsync(application); + default: + return ProcessUnknownFrameAsync(); + } + } + + private Task ProcessDataFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorHeadersInterleaved(_incomingFrame.Type, _incomingFrame.StreamId, _currentHeadersStream.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId == 0) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamIdZero(_incomingFrame.Type), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.DataHasPadding && _incomingFrame.DataPadLength >= _incomingFrame.Length) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorPaddingTooLong(_incomingFrame.Type), Http2ErrorCode.PROTOCOL_ERROR); + } + + ThrowIfIncomingFrameSentToIdleStream(); + + if (_streams.TryGetValue(_incomingFrame.StreamId, out var stream)) + { + if (stream.EndStreamReceived) + { + // http://httpwg.org/specs/rfc7540.html#rfc.section.5.1 + // + // ...an endpoint that receives any frames after receiving a frame with the + // END_STREAM flag set MUST treat that as a connection error (Section 5.4.1) + // of type STREAM_CLOSED, unless the frame is permitted as described below. + // + // (The allowed frame types for this situation are WINDOW_UPDATE, RST_STREAM and PRIORITY) + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamHalfClosedRemote(_incomingFrame.Type, stream.StreamId), Http2ErrorCode.STREAM_CLOSED); + } + + return stream.OnDataAsync(_incomingFrame.DataPayload, + endStream: (_incomingFrame.DataFlags & Http2DataFrameFlags.END_STREAM) == Http2DataFrameFlags.END_STREAM); + } + + // If we couldn't find the stream, it was either alive previously but closed with + // END_STREAM or RST_STREAM, or it was implicitly closed when the client opened + // a new stream with a higher ID. Per the spec, we should send RST_STREAM if + // the stream was closed with RST_STREAM or implicitly, but the spec also says + // in http://httpwg.org/specs/rfc7540.html#rfc.section.5.4.1 that + // + // An endpoint can end a connection at any time. In particular, an endpoint MAY + // choose to treat a stream error as a connection error. + // + // We choose to do that here so we don't have to keep state to track implicitly closed + // streams vs. streams closed with END_STREAM or RST_STREAM. + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamClosed(_incomingFrame.Type, _incomingFrame.StreamId), Http2ErrorCode.STREAM_CLOSED); + } + + private async Task ProcessHeadersFrameAsync(IHttpApplication application) + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorHeadersInterleaved(_incomingFrame.Type, _incomingFrame.StreamId, _currentHeadersStream.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId == 0) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamIdZero(_incomingFrame.Type), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.HeadersHasPadding && _incomingFrame.HeadersPadLength >= _incomingFrame.Length) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorPaddingTooLong(_incomingFrame.Type), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.HeadersHasPriority && _incomingFrame.HeadersStreamDependency == _incomingFrame.StreamId) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamSelfDependency(_incomingFrame.Type, _incomingFrame.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_streams.TryGetValue(_incomingFrame.StreamId, out var stream)) + { + // http://httpwg.org/specs/rfc7540.html#rfc.section.5.1 + // + // ...an endpoint that receives any frames after receiving a frame with the + // END_STREAM flag set MUST treat that as a connection error (Section 5.4.1) + // of type STREAM_CLOSED, unless the frame is permitted as described below. + // + // (The allowed frame types after END_STREAM are WINDOW_UPDATE, RST_STREAM and PRIORITY) + if (stream.EndStreamReceived) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamHalfClosedRemote(_incomingFrame.Type, stream.StreamId), Http2ErrorCode.STREAM_CLOSED); + } + + // This is the last chance for the client to send END_STREAM + if ((_incomingFrame.HeadersFlags & Http2HeadersFrameFlags.END_STREAM) == 0) + { + throw new Http2ConnectionErrorException(CoreStrings.Http2ErrorHeadersWithTrailersNoEndStream, Http2ErrorCode.PROTOCOL_ERROR); + } + + // Since we found an active stream, this HEADERS frame contains trailers + _currentHeadersStream = stream; + _requestHeaderParsingState = RequestHeaderParsingState.Trailers; + + var endHeaders = (_incomingFrame.HeadersFlags & Http2HeadersFrameFlags.END_HEADERS) == Http2HeadersFrameFlags.END_HEADERS; + await DecodeTrailersAsync(endHeaders, _incomingFrame.HeadersPayload); + } + else if (_incomingFrame.StreamId <= _highestOpenedStreamId) + { + // http://httpwg.org/specs/rfc7540.html#rfc.section.5.1.1 + // + // The first use of a new stream identifier implicitly closes all streams in the "idle" + // state that might have been initiated by that peer with a lower-valued stream identifier. + // + // If we couldn't find the stream, it was previously closed (either implicitly or with + // END_STREAM or RST_STREAM). + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamClosed(_incomingFrame.Type, _incomingFrame.StreamId), Http2ErrorCode.STREAM_CLOSED); + } + else + { + // Start a new stream + _currentHeadersStream = new Http2Stream(new Http2StreamContext + { + ConnectionId = ConnectionId, + StreamId = _incomingFrame.StreamId, + ServiceContext = _context.ServiceContext, + ConnectionFeatures = _context.ConnectionFeatures, + MemoryPool = _context.MemoryPool, + LocalEndPoint = _context.LocalEndPoint, + RemoteEndPoint = _context.RemoteEndPoint, + StreamLifetimeHandler = this, + FrameWriter = _frameWriter + }); + + if ((_incomingFrame.HeadersFlags & Http2HeadersFrameFlags.END_STREAM) == Http2HeadersFrameFlags.END_STREAM) + { + await _currentHeadersStream.OnDataAsync(Constants.EmptyData, endStream: true); + } + + _currentHeadersStream.Reset(); + + var endHeaders = (_incomingFrame.HeadersFlags & Http2HeadersFrameFlags.END_HEADERS) == Http2HeadersFrameFlags.END_HEADERS; + await DecodeHeadersAsync(application, endHeaders, _incomingFrame.HeadersPayload); + } + } + + private Task ProcessPriorityFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorHeadersInterleaved(_incomingFrame.Type, _incomingFrame.StreamId, _currentHeadersStream.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId == 0) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamIdZero(_incomingFrame.Type), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.PriorityStreamDependency == _incomingFrame.StreamId) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamSelfDependency(_incomingFrame.Type, _incomingFrame.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.Length != 5) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorUnexpectedFrameLength(_incomingFrame.Type, 5), Http2ErrorCode.FRAME_SIZE_ERROR); + } + + return Task.CompletedTask; + } + + private Task ProcessRstStreamFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorHeadersInterleaved(_incomingFrame.Type, _incomingFrame.StreamId, _currentHeadersStream.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId == 0) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamIdZero(_incomingFrame.Type), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.Length != 4) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorUnexpectedFrameLength(_incomingFrame.Type, 4), Http2ErrorCode.FRAME_SIZE_ERROR); + } + + ThrowIfIncomingFrameSentToIdleStream(); + + if (_streams.TryGetValue(_incomingFrame.StreamId, out var stream)) + { + stream.Abort(abortReason: null); + } + + return Task.CompletedTask; + } + + private Task ProcessSettingsFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorHeadersInterleaved(_incomingFrame.Type, _incomingFrame.StreamId, _currentHeadersStream.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId != 0) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamIdNotZero(_incomingFrame.Type), Http2ErrorCode.PROTOCOL_ERROR); + } + + if ((_incomingFrame.SettingsFlags & Http2SettingsFrameFlags.ACK) == Http2SettingsFrameFlags.ACK && _incomingFrame.Length != 0) + { + throw new Http2ConnectionErrorException(CoreStrings.Http2ErrorSettingsAckLengthNotZero, Http2ErrorCode.FRAME_SIZE_ERROR); + } + + if (_incomingFrame.Length % 6 != 0) + { + throw new Http2ConnectionErrorException(CoreStrings.Http2ErrorSettingsLengthNotMultipleOfSix, Http2ErrorCode.FRAME_SIZE_ERROR); + } + + try + { + _clientSettings.ParseFrame(_incomingFrame); + return _frameWriter.WriteSettingsAckAsync(); + } + catch (Http2SettingsParameterOutOfRangeException ex) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorSettingsParameterOutOfRange(ex.Parameter), ex.Parameter == Http2SettingsParameter.SETTINGS_INITIAL_WINDOW_SIZE + ? Http2ErrorCode.FLOW_CONTROL_ERROR + : Http2ErrorCode.PROTOCOL_ERROR); + } + } + + private Task ProcessPingFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorHeadersInterleaved(_incomingFrame.Type, _incomingFrame.StreamId, _currentHeadersStream.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId != 0) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamIdNotZero(_incomingFrame.Type), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.Length != 8) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorUnexpectedFrameLength(_incomingFrame.Type, 8), Http2ErrorCode.FRAME_SIZE_ERROR); + } + + if ((_incomingFrame.PingFlags & Http2PingFrameFlags.ACK) == Http2PingFrameFlags.ACK) + { + // TODO: verify that payload is equal to the outgoing PING frame + return Task.CompletedTask; + } + + return _frameWriter.WritePingAsync(Http2PingFrameFlags.ACK, _incomingFrame.Payload); + } + + private Task ProcessGoAwayFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorHeadersInterleaved(_incomingFrame.Type, _incomingFrame.StreamId, _currentHeadersStream.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId != 0) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamIdNotZero(_incomingFrame.Type), Http2ErrorCode.PROTOCOL_ERROR); + } + + StopProcessingNextRequest(); + return Task.CompletedTask; + } + + private Task ProcessWindowUpdateFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorHeadersInterleaved(_incomingFrame.Type, _incomingFrame.StreamId, _currentHeadersStream.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.Length != 4) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorUnexpectedFrameLength(_incomingFrame.Type, 4), Http2ErrorCode.FRAME_SIZE_ERROR); + } + + ThrowIfIncomingFrameSentToIdleStream(); + + if (_incomingFrame.WindowUpdateSizeIncrement == 0) + { + // http://httpwg.org/specs/rfc7540.html#rfc.section.6.9 + // A receiver MUST treat the receipt of a WINDOW_UPDATE + // frame with an flow-control window increment of 0 as a + // stream error (Section 5.4.2) of type PROTOCOL_ERROR; + // errors on the connection flow-control window MUST be + // treated as a connection error (Section 5.4.1). + // + // http://httpwg.org/specs/rfc7540.html#rfc.section.5.4.1 + // An endpoint can end a connection at any time. In + // particular, an endpoint MAY choose to treat a stream + // error as a connection error. + // + // Since server initiated stream resets are not yet properly + // implemented and tested, we treat all zero length window + // increments as connection errors for now. + throw new Http2ConnectionErrorException(CoreStrings.Http2ErrorWindowUpdateIncrementZero, Http2ErrorCode.PROTOCOL_ERROR); + } + + return Task.CompletedTask; + } + + private Task ProcessContinuationFrameAsync(IHttpApplication application) + { + if (_currentHeadersStream == null) + { + throw new Http2ConnectionErrorException(CoreStrings.Http2ErrorContinuationWithNoHeaders, Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId != _currentHeadersStream.StreamId) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorHeadersInterleaved(_incomingFrame.Type, _incomingFrame.StreamId, _currentHeadersStream.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + var endHeaders = (_incomingFrame.ContinuationFlags & Http2ContinuationFrameFlags.END_HEADERS) == Http2ContinuationFrameFlags.END_HEADERS; + + if (_requestHeaderParsingState == RequestHeaderParsingState.Trailers) + { + return DecodeTrailersAsync(endHeaders, _incomingFrame.Payload); + } + else + { + return DecodeHeadersAsync(application, endHeaders, _incomingFrame.Payload); + } + } + + private Task ProcessUnknownFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorHeadersInterleaved(_incomingFrame.Type, _incomingFrame.StreamId, _currentHeadersStream.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + + return Task.CompletedTask; + } + + private Task DecodeHeadersAsync(IHttpApplication application, bool endHeaders, Span payload) + { + try + { + _hpackDecoder.Decode(payload, endHeaders, handler: this); + + if (endHeaders) + { + StartStream(application); + ResetRequestHeaderParsingState(); + } + } + catch (Http2StreamErrorException ex) + { + Log.Http2StreamError(ConnectionId, ex); + ResetRequestHeaderParsingState(); + return _frameWriter.WriteRstStreamAsync(ex.StreamId, ex.ErrorCode); + } + + return Task.CompletedTask; + } + + private Task DecodeTrailersAsync(bool endHeaders, Span payload) + { + _hpackDecoder.Decode(payload, endHeaders, handler: this); + + if (endHeaders) + { + var endStreamTask = _currentHeadersStream.OnDataAsync(Constants.EmptyData, endStream: true); + ResetRequestHeaderParsingState(); + return endStreamTask; + } + + return Task.CompletedTask; + } + + private void StartStream(IHttpApplication application) + { + if (!_isMethodConnect && (_parsedPseudoHeaderFields & _mandatoryRequestPseudoHeaderFields) != _mandatoryRequestPseudoHeaderFields) + { + // All HTTP/2 requests MUST include exactly one valid value for the :method, :scheme, and :path pseudo-header + // fields, unless it is a CONNECT request (Section 8.3). An HTTP request that omits mandatory pseudo-header + // fields is malformed (Section 8.1.2.6). + throw new Http2StreamErrorException(_currentHeadersStream.StreamId, CoreStrings.Http2ErrorMissingMandatoryPseudoHeaderFields, Http2ErrorCode.PROTOCOL_ERROR); + } + + _streams[_incomingFrame.StreamId] = _currentHeadersStream; + _ = _currentHeadersStream.ProcessRequestsAsync(application); + } + + private void ResetRequestHeaderParsingState() + { + if (_requestHeaderParsingState != RequestHeaderParsingState.Trailers) + { + _highestOpenedStreamId = _currentHeadersStream.StreamId; + } + + _currentHeadersStream = null; + _requestHeaderParsingState = RequestHeaderParsingState.Ready; + _parsedPseudoHeaderFields = PseudoHeaderFields.None; + _isMethodConnect = false; + } + + private void ThrowIfIncomingFrameSentToIdleStream() + { + // http://httpwg.org/specs/rfc7540.html#rfc.section.5.1 + // 5.1. Stream states + // ... + // idle: + // ... + // Receiving any frame other than HEADERS or PRIORITY on a stream in this state MUST be + // treated as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. + // + // If the stream ID in the incoming frame is higher than the highest opened stream ID so + // far, then the incoming frame's target stream is in the idle state, which is the implicit + // initial state for all streams. + if (_incomingFrame.StreamId > _highestOpenedStreamId) + { + throw new Http2ConnectionErrorException(CoreStrings.FormatHttp2ErrorStreamIdle(_incomingFrame.Type, _incomingFrame.StreamId), Http2ErrorCode.PROTOCOL_ERROR); + } + } + + void IHttp2StreamLifetimeHandler.OnStreamCompleted(int streamId) + { + _streams.TryRemove(streamId, out _); + } + + public void OnHeader(Span name, Span value) + { + ValidateHeader(name, value); + _currentHeadersStream.OnHeader(name, value); + } + + private void ValidateHeader(Span name, Span value) + { + // http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.1 + if (IsPseudoHeaderField(name, out var headerField)) + { + if (_requestHeaderParsingState == RequestHeaderParsingState.Headers) + { + // All pseudo-header fields MUST appear in the header block before regular header fields. + // Any request or response that contains a pseudo-header field that appears in a header + // block after a regular header field MUST be treated as malformed (Section 8.1.2.6). + throw new Http2StreamErrorException(_currentHeadersStream.StreamId, CoreStrings.Http2ErrorPseudoHeaderFieldAfterRegularHeaders, Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_requestHeaderParsingState == RequestHeaderParsingState.Trailers) + { + // Pseudo-header fields MUST NOT appear in trailers. + throw new Http2ConnectionErrorException(CoreStrings.Http2ErrorTrailersContainPseudoHeaderField, Http2ErrorCode.PROTOCOL_ERROR); + } + + _requestHeaderParsingState = RequestHeaderParsingState.PseudoHeaderFields; + + if (headerField == PseudoHeaderFields.Unknown) + { + // Endpoints MUST treat a request or response that contains undefined or invalid pseudo-header + // fields as malformed (Section 8.1.2.6). + throw new Http2StreamErrorException(_currentHeadersStream.StreamId, CoreStrings.Http2ErrorUnknownPseudoHeaderField, Http2ErrorCode.PROTOCOL_ERROR); + } + + if (headerField == PseudoHeaderFields.Status) + { + // Pseudo-header fields defined for requests MUST NOT appear in responses; pseudo-header fields + // defined for responses MUST NOT appear in requests. + throw new Http2StreamErrorException(_currentHeadersStream.StreamId, CoreStrings.Http2ErrorResponsePseudoHeaderField, Http2ErrorCode.PROTOCOL_ERROR); + } + + if ((_parsedPseudoHeaderFields & headerField) == headerField) + { + // http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.3 + // All HTTP/2 requests MUST include exactly one valid value for the :method, :scheme, and :path pseudo-header fields + throw new Http2StreamErrorException(_currentHeadersStream.StreamId, CoreStrings.Http2ErrorDuplicatePseudoHeaderField, Http2ErrorCode.PROTOCOL_ERROR); + } + + if (headerField == PseudoHeaderFields.Method) + { + _isMethodConnect = value.SequenceEqual(_connectBytes); + } + + _parsedPseudoHeaderFields |= headerField; + } + else if (_requestHeaderParsingState != RequestHeaderParsingState.Trailers) + { + _requestHeaderParsingState = RequestHeaderParsingState.Headers; + } + + if (IsConnectionSpecificHeaderField(name, value)) + { + throw new Http2StreamErrorException(_currentHeadersStream.StreamId, CoreStrings.Http2ErrorConnectionSpecificHeaderField, Http2ErrorCode.PROTOCOL_ERROR); + } + + // http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2 + // A request or response containing uppercase header field names MUST be treated as malformed (Section 8.1.2.6). + for (var i = 0; i < name.Length; i++) + { + if (name[i] >= 65 && name[i] <= 90) + { + if (_requestHeaderParsingState == RequestHeaderParsingState.Trailers) + { + throw new Http2ConnectionErrorException(CoreStrings.Http2ErrorTrailerNameUppercase, Http2ErrorCode.PROTOCOL_ERROR); + } + else + { + throw new Http2StreamErrorException(_currentHeadersStream.StreamId, CoreStrings.Http2ErrorHeaderNameUppercase, Http2ErrorCode.PROTOCOL_ERROR); + } + } + } + } + + private bool IsPseudoHeaderField(Span name, out PseudoHeaderFields headerField) + { + headerField = PseudoHeaderFields.None; + + if (name.IsEmpty || name[0] != (byte)':') + { + return false; + } + + // Skip ':' + name = name.Slice(1); + + if (name.SequenceEqual(_pathBytes)) + { + headerField = PseudoHeaderFields.Path; + } + else if (name.SequenceEqual(_methodBytes)) + { + headerField = PseudoHeaderFields.Method; + } + else if (name.SequenceEqual(_schemeBytes)) + { + headerField = PseudoHeaderFields.Scheme; + } + else if (name.SequenceEqual(_statusBytes)) + { + headerField = PseudoHeaderFields.Status; + } + else if (name.SequenceEqual(_authorityBytes)) + { + headerField = PseudoHeaderFields.Authority; + } + else + { + headerField = PseudoHeaderFields.Unknown; + } + + return true; + } + + private static bool IsConnectionSpecificHeaderField(Span name, Span value) + { + return name.SequenceEqual(_connectionBytes) || (name.SequenceEqual(_teBytes) && !value.SequenceEqual(_trailersBytes)); + } + + void ITimeoutControl.SetTimeout(long ticks, TimeoutAction timeoutAction) + { + } + + void ITimeoutControl.ResetTimeout(long ticks, TimeoutAction timeoutAction) + { + } + + void ITimeoutControl.CancelTimeout() + { + } + + void ITimeoutControl.StartTimingReads() + { + } + + void ITimeoutControl.PauseTimingReads() + { + } + + void ITimeoutControl.ResumeTimingReads() + { + } + + void ITimeoutControl.StopTimingReads() + { + } + + void ITimeoutControl.BytesRead(long count) + { + } + + void ITimeoutControl.StartTimingWrite(long size) + { + } + + void ITimeoutControl.StopTimingWrite() + { + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2ConnectionContext.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2ConnectionContext.cs new file mode 100644 index 0000000000..ae9ceb3b70 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2ConnectionContext.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.IO.Pipelines; +using System.Net; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2ConnectionContext + { + public string ConnectionId { get; set; } + public ServiceContext ServiceContext { get; set; } + public IFeatureCollection ConnectionFeatures { get; set; } + public MemoryPool MemoryPool { get; set; } + public IPEndPoint LocalEndPoint { get; set; } + public IPEndPoint RemoteEndPoint { get; set; } + + public IDuplexPipe Transport { get; set; } + public IDuplexPipe Application { get; set; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2ConnectionErrorException.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2ConnectionErrorException.cs new file mode 100644 index 0000000000..dd1314b1a5 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2ConnectionErrorException.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2ConnectionErrorException : Exception + { + public Http2ConnectionErrorException(string message, Http2ErrorCode errorCode) + : base($"HTTP/2 connection error ({errorCode}): {message}") + { + ErrorCode = errorCode; + } + + public Http2ErrorCode ErrorCode { get; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2ContinuationFrameFlags.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2ContinuationFrameFlags.cs new file mode 100644 index 0000000000..65e65bc0bc --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2ContinuationFrameFlags.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + [Flags] + public enum Http2ContinuationFrameFlags : byte + { + NONE = 0x0, + END_HEADERS = 0x4, + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2DataFrameFlags.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2DataFrameFlags.cs new file mode 100644 index 0000000000..735a4aea30 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2DataFrameFlags.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + [Flags] + public enum Http2DataFrameFlags : byte + { + NONE = 0x0, + END_STREAM = 0x1, + PADDED = 0x8 + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2ErrorCode.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2ErrorCode.cs new file mode 100644 index 0000000000..401350fb39 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2ErrorCode.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public enum Http2ErrorCode : uint + { + NO_ERROR = 0x0, + PROTOCOL_ERROR = 0x1, + INTERNAL_ERROR = 0x2, + FLOW_CONTROL_ERROR = 0x3, + SETTINGS_TIMEOUT = 0x4, + STREAM_CLOSED = 0x5, + FRAME_SIZE_ERROR = 0x6, + REFUSED_STREAM = 0x7, + CANCEL = 0x8, + COMPRESSION_ERROR = 0x9, + CONNECT_ERROR = 0xa, + ENHANCE_YOUR_CALM = 0xb, + INADEQUATE_SECURITY = 0xc, + HTTP_1_1_REQUIRED = 0xd, + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Continuation.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Continuation.cs new file mode 100644 index 0000000000..d599864b7b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Continuation.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public Http2ContinuationFrameFlags ContinuationFlags + { + get => (Http2ContinuationFrameFlags)Flags; + set => Flags = (byte)value; + } + + public void PrepareContinuation(Http2ContinuationFrameFlags flags, int streamId) + { + Length = MinAllowedMaxFrameSize - HeaderLength; + Type = Http2FrameType.CONTINUATION; + ContinuationFlags = flags; + StreamId = streamId; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Data.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Data.cs new file mode 100644 index 0000000000..91f0edb72a --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Data.cs @@ -0,0 +1,50 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public Http2DataFrameFlags DataFlags + { + get => (Http2DataFrameFlags)Flags; + set => Flags = (byte)value; + } + + public bool DataHasPadding => (DataFlags & Http2DataFrameFlags.PADDED) == Http2DataFrameFlags.PADDED; + + public byte DataPadLength + { + get => DataHasPadding ? _data[PayloadOffset] : (byte)0; + set => _data[PayloadOffset] = value; + } + + public ArraySegment DataPayload => DataHasPadding + ? new ArraySegment(_data, PayloadOffset + 1, Length - DataPadLength - 1) + : new ArraySegment(_data, PayloadOffset, Length); + + public void PrepareData(int streamId, byte? padLength = null) + { + var padded = padLength != null; + + Length = MinAllowedMaxFrameSize; + Type = Http2FrameType.DATA; + DataFlags = padded ? Http2DataFrameFlags.PADDED : Http2DataFrameFlags.NONE; + StreamId = streamId; + + if (padded) + { + DataPadLength = padLength.Value; + Payload.Slice(Length - padLength.Value).Fill(0); + } + } + + private void DataTraceFrame(ILogger logger) + { + logger.LogTrace("'DATA' Frame. Flags = {DataFlags}, PadLength = {PadLength}, PayloadLength = {PayloadLength}", DataFlags, DataPadLength, DataPayload.Count); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.GoAway.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.GoAway.cs new file mode 100644 index 0000000000..3e430de8b0 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.GoAway.cs @@ -0,0 +1,42 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public int GoAwayLastStreamId + { + get => (Payload[0] << 24) | (Payload[1] << 16) | (Payload[2] << 16) | Payload[3]; + set + { + Payload[0] = (byte)((value >> 24) & 0xff); + Payload[1] = (byte)((value >> 16) & 0xff); + Payload[2] = (byte)((value >> 8) & 0xff); + Payload[3] = (byte)(value & 0xff); + } + } + + public Http2ErrorCode GoAwayErrorCode + { + get => (Http2ErrorCode)((Payload[4] << 24) | (Payload[5] << 16) | (Payload[6] << 16) | Payload[7]); + set + { + Payload[4] = (byte)(((uint)value >> 24) & 0xff); + Payload[5] = (byte)(((uint)value >> 16) & 0xff); + Payload[6] = (byte)(((uint)value >> 8) & 0xff); + Payload[7] = (byte)((uint)value & 0xff); + } + } + + public void PrepareGoAway(int lastStreamId, Http2ErrorCode errorCode) + { + Length = 8; + Type = Http2FrameType.GOAWAY; + Flags = 0; + StreamId = 0; + GoAwayLastStreamId = lastStreamId; + GoAwayErrorCode = errorCode; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Headers.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Headers.cs new file mode 100644 index 0000000000..52c20bde94 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Headers.cs @@ -0,0 +1,72 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public Http2HeadersFrameFlags HeadersFlags + { + get => (Http2HeadersFrameFlags)Flags; + set => Flags = (byte)value; + } + + public bool HeadersHasPadding => (HeadersFlags & Http2HeadersFrameFlags.PADDED) == Http2HeadersFrameFlags.PADDED; + + public byte HeadersPadLength + { + get => HeadersHasPadding ? _data[HeaderLength] : (byte)0; + set => _data[HeaderLength] = value; + } + + public bool HeadersHasPriority => (HeadersFlags & Http2HeadersFrameFlags.PRIORITY) == Http2HeadersFrameFlags.PRIORITY; + + public byte HeadersPriority + { + get => _data[HeadersPriorityOffset]; + set => _data[HeadersPriorityOffset] = value; + } + + private int HeadersPriorityOffset => PayloadOffset + (HeadersHasPadding ? 1 : 0) + 4; + + public int HeadersStreamDependency + { + get + { + var offset = HeadersStreamDependencyOffset; + + return (int)((uint)((_data[offset] << 24) + | (_data[offset + 1] << 16) + | (_data[offset + 2] << 8) + | _data[offset + 3]) & 0x7fffffff); + } + set + { + var offset = HeadersStreamDependencyOffset; + + _data[offset] = (byte)((value & 0xff000000) >> 24); + _data[offset + 1] = (byte)((value & 0x00ff0000) >> 16); + _data[offset + 2] = (byte)((value & 0x0000ff00) >> 8); + _data[offset + 3] = (byte)(value & 0x000000ff); + } + } + + private int HeadersStreamDependencyOffset => PayloadOffset + (HeadersHasPadding ? 1 : 0); + + public Span HeadersPayload => new Span(_data, HeadersPayloadOffset, HeadersPayloadLength); + + private int HeadersPayloadOffset => PayloadOffset + (HeadersHasPadding ? 1 : 0) + (HeadersHasPriority ? 5 : 0); + + private int HeadersPayloadLength => Length - ((HeadersHasPadding ? 1 : 0) + (HeadersHasPriority ? 5 : 0)) - HeadersPadLength; + + public void PrepareHeaders(Http2HeadersFrameFlags flags, int streamId) + { + Length = MinAllowedMaxFrameSize - HeaderLength; + Type = Http2FrameType.HEADERS; + HeadersFlags = flags; + StreamId = streamId; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Ping.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Ping.cs new file mode 100644 index 0000000000..cbbdc88d41 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Ping.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public Http2PingFrameFlags PingFlags + { + get => (Http2PingFrameFlags)Flags; + set => Flags = (byte)value; + } + + public void PreparePing(Http2PingFrameFlags flags) + { + Length = 8; + Type = Http2FrameType.PING; + PingFlags = flags; + StreamId = 0; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Priority.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Priority.cs new file mode 100644 index 0000000000..02f9bf02f9 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Priority.cs @@ -0,0 +1,57 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public int PriorityStreamDependency + { + get => ((_data[PayloadOffset] << 24) + | (_data[PayloadOffset + 1] << 16) + | (_data[PayloadOffset + 2] << 8) + | _data[PayloadOffset + 3]) & 0x7fffffff; + set + { + _data[PayloadOffset] = (byte)((value & 0x7f000000) >> 24); + _data[PayloadOffset + 1] = (byte)((value & 0x00ff0000) >> 16); + _data[PayloadOffset + 2] = (byte)((value & 0x0000ff00) >> 8); + _data[PayloadOffset + 3] = (byte)(value & 0x000000ff); + } + } + + + public bool PriorityIsExclusive + { + get => (_data[PayloadOffset] & 0x80000000) != 0; + set + { + if (value) + { + _data[PayloadOffset] |= 0x80; + } + else + { + _data[PayloadOffset] &= 0x7f; + } + } + } + + public byte PriorityWeight + { + get => _data[PayloadOffset + 4]; + set => _data[PayloadOffset] = value; + } + + + public void PreparePriority(int streamId, int streamDependency, bool exclusive, byte weight) + { + Length = 5; + Type = Http2FrameType.PRIORITY; + StreamId = streamId; + PriorityStreamDependency = streamDependency; + PriorityIsExclusive = exclusive; + PriorityWeight = weight; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.RstStream.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.RstStream.cs new file mode 100644 index 0000000000..8a0bcdfd6c --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.RstStream.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public Http2ErrorCode RstStreamErrorCode + { + get => (Http2ErrorCode)((Payload[0] << 24) | (Payload[1] << 16) | (Payload[2] << 16) | Payload[3]); + set + { + Payload[0] = (byte)(((uint)value >> 24) & 0xff); + Payload[1] = (byte)(((uint)value >> 16) & 0xff); + Payload[2] = (byte)(((uint)value >> 8) & 0xff); + Payload[3] = (byte)((uint)value & 0xff); + } + } + + public void PrepareRstStream(int streamId, Http2ErrorCode errorCode) + { + Length = 4; + Type = Http2FrameType.RST_STREAM; + Flags = 0; + StreamId = streamId; + RstStreamErrorCode = errorCode; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Settings.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Settings.cs new file mode 100644 index 0000000000..04cc78b209 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.Settings.cs @@ -0,0 +1,42 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public Http2SettingsFrameFlags SettingsFlags + { + get => (Http2SettingsFrameFlags)Flags; + set => Flags = (byte)value; + } + + public void PrepareSettings(Http2SettingsFrameFlags flags, Http2PeerSettings settings = null) + { + var settingCount = settings?.Count() ?? 0; + + Length = 6 * settingCount; + Type = Http2FrameType.SETTINGS; + SettingsFlags = flags; + StreamId = 0; + + if (settings != null) + { + Span payload = Payload; + foreach (var setting in settings) + { + payload[0] = (byte)((ushort)setting.Parameter >> 8); + payload[1] = (byte)(ushort)setting.Parameter; + payload[2] = (byte)(setting.Value >> 24); + payload[3] = (byte)(setting.Value >> 16); + payload[4] = (byte)(setting.Value >> 8); + payload[5] = (byte)setting.Value; + payload = payload.Slice(6); + } + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.WindowUpdate.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.WindowUpdate.cs new file mode 100644 index 0000000000..6958b376f5 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.WindowUpdate.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public int WindowUpdateSizeIncrement + { + get => ((Payload[0] << 24) | (Payload[1] << 16) | (Payload[2] << 16) | Payload[3]) & 0x7fffffff; + set + { + Payload[0] = (byte)(((uint)value >> 24) & 0x7f); + Payload[1] = (byte)(((uint)value >> 16) & 0xff); + Payload[2] = (byte)(((uint)value >> 8) & 0xff); + Payload[3] = (byte)((uint)value & 0xff); + } + } + + public void PrepareWindowUpdate(int streamId, int sizeIncrement) + { + Length = 4; + Type = Http2FrameType.WINDOW_UPDATE; + Flags = 0; + StreamId = streamId; + WindowUpdateSizeIncrement = sizeIncrement; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.cs new file mode 100644 index 0000000000..0d693c0ced --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Frame.cs @@ -0,0 +1,70 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public const int MinAllowedMaxFrameSize = 16 * 1024; + public const int MaxAllowedMaxFrameSize = 16 * 1024 * 1024 - 1; + public const int HeaderLength = 9; + + private const int LengthOffset = 0; + private const int TypeOffset = 3; + private const int FlagsOffset = 4; + private const int StreamIdOffset = 5; + private const int PayloadOffset = 9; + + private readonly byte[] _data = new byte[HeaderLength + MinAllowedMaxFrameSize]; + + public Span Raw => new Span(_data, 0, HeaderLength + Length); + + public int Length + { + get => (_data[LengthOffset] << 16) | (_data[LengthOffset + 1] << 8) | _data[LengthOffset + 2]; + set + { + _data[LengthOffset] = (byte)((value & 0x00ff0000) >> 16); + _data[LengthOffset + 1] = (byte)((value & 0x0000ff00) >> 8); + _data[LengthOffset + 2] = (byte)(value & 0x000000ff); + } + } + + public Http2FrameType Type + { + get => (Http2FrameType)_data[TypeOffset]; + set + { + _data[TypeOffset] = (byte)value; + } + } + + public byte Flags + { + get => _data[FlagsOffset]; + set + { + _data[FlagsOffset] = (byte)value; + } + } + + public int StreamId + { + get => (int)((uint)((_data[StreamIdOffset] << 24) + | (_data[StreamIdOffset + 1] << 16) + | (_data[StreamIdOffset + 2] << 8) + | _data[StreamIdOffset + 3]) & 0x7fffffff); + set + { + _data[StreamIdOffset] = (byte)((value & 0xff000000) >> 24); + _data[StreamIdOffset + 1] = (byte)((value & 0x00ff0000) >> 16); + _data[StreamIdOffset + 2] = (byte)((value & 0x0000ff00) >> 8); + _data[StreamIdOffset + 3] = (byte)(value & 0x000000ff); + } + } + + public Span Payload => new Span(_data, PayloadOffset, Length); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameReader.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameReader.cs new file mode 100644 index 0000000000..6b2ab584db --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameReader.cs @@ -0,0 +1,36 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public static class Http2FrameReader + { + public static bool ReadFrame(ReadOnlySequence readableBuffer, Http2Frame frame, out SequencePosition consumed, out SequencePosition examined) + { + consumed = readableBuffer.Start; + examined = readableBuffer.End; + + if (readableBuffer.Length < Http2Frame.HeaderLength) + { + return false; + } + + var headerSlice = readableBuffer.Slice(0, Http2Frame.HeaderLength); + headerSlice.CopyTo(frame.Raw); + + if (readableBuffer.Length < Http2Frame.HeaderLength + frame.Length) + { + return false; + } + + readableBuffer.Slice(Http2Frame.HeaderLength, frame.Length).CopyTo(frame.Payload); + consumed = examined = readableBuffer.GetPosition(Http2Frame.HeaderLength + frame.Length); + + return true; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameType.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameType.cs new file mode 100644 index 0000000000..a09272a6be --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameType.cs @@ -0,0 +1,19 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public enum Http2FrameType : byte + { + DATA = 0x0, + HEADERS = 0x1, + PRIORITY = 0x2, + RST_STREAM = 0x3, + SETTINGS = 0x4, + PUSH_PROMISE = 0x5, + PING = 0x6, + GOAWAY = 0x7, + WINDOW_UPDATE = 0x8, + CONTINUATION = 0x9 + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs new file mode 100644 index 0000000000..e31d4b7f4f --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs @@ -0,0 +1,223 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2FrameWriter : IHttp2FrameWriter + { + // Literal Header Field without Indexing - Indexed Name (Index 8 - :status) + private static readonly byte[] _continueBytes = new byte[] { 0x08, 0x03, (byte)'1', (byte)'0', (byte)'0' }; + + private readonly Http2Frame _outgoingFrame = new Http2Frame(); + private readonly object _writeLock = new object(); + private readonly HPackEncoder _hpackEncoder = new HPackEncoder(); + private readonly PipeWriter _outputWriter; + private readonly PipeReader _outputReader; + + private bool _completed; + + public Http2FrameWriter(PipeWriter outputPipeWriter, PipeReader outputPipeReader) + { + _outputWriter = outputPipeWriter; + _outputReader = outputPipeReader; + } + + public void Abort(Exception ex) + { + lock (_writeLock) + { + if (_completed) + { + return; + } + + _completed = true; + _outputReader.CancelPendingRead(); + _outputWriter.Complete(ex); + } + } + + public Task FlushAsync(CancellationToken cancellationToken) + { + lock (_writeLock) + { + return WriteAsync(Constants.EmptyData); + } + } + + public Task Write100ContinueAsync(int streamId) + { + lock (_writeLock) + { + _outgoingFrame.PrepareHeaders(Http2HeadersFrameFlags.END_HEADERS, streamId); + _outgoingFrame.Length = _continueBytes.Length; + _continueBytes.CopyTo(_outgoingFrame.HeadersPayload); + + return WriteAsync(_outgoingFrame.Raw); + } + } + + public void WriteResponseHeaders(int streamId, int statusCode, IHeaderDictionary headers) + { + lock (_writeLock) + { + _outgoingFrame.PrepareHeaders(Http2HeadersFrameFlags.NONE, streamId); + + var done = _hpackEncoder.BeginEncode(statusCode, EnumerateHeaders(headers), _outgoingFrame.Payload, out var payloadLength); + _outgoingFrame.Length = payloadLength; + + if (done) + { + _outgoingFrame.HeadersFlags = Http2HeadersFrameFlags.END_HEADERS; + } + + Append(_outgoingFrame.Raw); + + while (!done) + { + _outgoingFrame.PrepareContinuation(Http2ContinuationFrameFlags.NONE, streamId); + + done = _hpackEncoder.Encode(_outgoingFrame.Payload, out var length); + _outgoingFrame.Length = length; + + if (done) + { + _outgoingFrame.ContinuationFlags = Http2ContinuationFrameFlags.END_HEADERS; + } + + Append(_outgoingFrame.Raw); + } + } + } + + public Task WriteDataAsync(int streamId, ReadOnlySpan data, CancellationToken cancellationToken) + => WriteDataAsync(streamId, data, endStream: false, cancellationToken: cancellationToken); + + public Task WriteDataAsync(int streamId, ReadOnlySpan data, bool endStream, CancellationToken cancellationToken) + { + var tasks = new List(); + + lock (_writeLock) + { + _outgoingFrame.PrepareData(streamId); + + while (data.Length > _outgoingFrame.Length) + { + data.Slice(0, _outgoingFrame.Length).CopyTo(_outgoingFrame.Payload); + data = data.Slice(_outgoingFrame.Length); + + tasks.Add(WriteAsync(_outgoingFrame.Raw, cancellationToken)); + } + + _outgoingFrame.Length = data.Length; + + if (endStream) + { + _outgoingFrame.DataFlags = Http2DataFrameFlags.END_STREAM; + } + + data.CopyTo(_outgoingFrame.Payload); + + tasks.Add(WriteAsync(_outgoingFrame.Raw, cancellationToken)); + + return Task.WhenAll(tasks); + } + } + + public Task WriteRstStreamAsync(int streamId, Http2ErrorCode errorCode) + { + lock (_writeLock) + { + _outgoingFrame.PrepareRstStream(streamId, errorCode); + return WriteAsync(_outgoingFrame.Raw); + } + } + + public Task WriteSettingsAsync(Http2PeerSettings settings) + { + lock (_writeLock) + { + // TODO: actually send settings + _outgoingFrame.PrepareSettings(Http2SettingsFrameFlags.NONE); + return WriteAsync(_outgoingFrame.Raw); + } + } + + public Task WriteSettingsAckAsync() + { + lock (_writeLock) + { + _outgoingFrame.PrepareSettings(Http2SettingsFrameFlags.ACK); + return WriteAsync(_outgoingFrame.Raw); + } + } + + public Task WritePingAsync(Http2PingFrameFlags flags, ReadOnlySpan payload) + { + lock (_writeLock) + { + _outgoingFrame.PreparePing(Http2PingFrameFlags.ACK); + payload.CopyTo(_outgoingFrame.Payload); + return WriteAsync(_outgoingFrame.Raw); + } + } + + public Task WriteGoAwayAsync(int lastStreamId, Http2ErrorCode errorCode) + { + lock (_writeLock) + { + _outgoingFrame.PrepareGoAway(lastStreamId, errorCode); + return WriteAsync(_outgoingFrame.Raw); + } + } + + // Must be called with _writeLock + private void Append(ReadOnlySpan data) + { + if (_completed) + { + return; + } + + _outputWriter.Write(data); + } + + // Must be called with _writeLock + private Task WriteAsync(ReadOnlySpan data, CancellationToken cancellationToken = default(CancellationToken)) + { + if (_completed) + { + return Task.CompletedTask; + } + + _outputWriter.Write(data); + return FlushAsync(_outputWriter, cancellationToken); + } + + private async Task FlushAsync(PipeWriter outputWriter, CancellationToken cancellationToken) + { + await outputWriter.FlushAsync(cancellationToken); + } + + private static IEnumerable> EnumerateHeaders(IHeaderDictionary headers) + { + foreach (var header in headers) + { + foreach (var value in header.Value) + { + yield return new KeyValuePair(header.Key, value); + } + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2HeadersFrameFlags.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2HeadersFrameFlags.cs new file mode 100644 index 0000000000..564371e1be --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2HeadersFrameFlags.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + [Flags] + public enum Http2HeadersFrameFlags : byte + { + NONE = 0x0, + END_STREAM = 0x1, + END_HEADERS = 0x4, + PADDED = 0x8, + PRIORITY = 0x20 + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs new file mode 100644 index 0000000000..b6ad3af161 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs @@ -0,0 +1,57 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public abstract class Http2MessageBody : MessageBody + { + private readonly Http2Stream _context; + + protected Http2MessageBody(Http2Stream context) + : base(context) + { + _context = context; + } + + protected override void OnReadStarted() + { + // Produce 100-continue if no request body data for the stream has arrived yet. + if (!_context.RequestBodyStarted) + { + TryProduceContinue(); + } + } + + protected override Task OnConsumeAsync() => Task.CompletedTask; + + public override Task StopAsync() + { + _context.RequestBodyPipe.Reader.Complete(); + _context.RequestBodyPipe.Writer.Complete(); + return Task.CompletedTask; + } + + public static MessageBody For( + HttpRequestHeaders headers, + Http2Stream context) + { + if (context.EndStreamReceived) + { + return ZeroContentLengthClose; + } + + return new ForHttp2(context); + } + + private class ForHttp2 : Http2MessageBody + { + public ForHttp2(Http2Stream context) + : base(context) + { + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs new file mode 100644 index 0000000000..e701654d1e --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs @@ -0,0 +1,58 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2OutputProducer : IHttpOutputProducer + { + private readonly int _streamId; + private readonly IHttp2FrameWriter _frameWriter; + + public Http2OutputProducer(int streamId, IHttp2FrameWriter frameWriter) + { + _streamId = streamId; + _frameWriter = frameWriter; + } + + public void Dispose() + { + } + + public void Abort(ConnectionAbortedException error) + { + // TODO: RST_STREAM? + } + + public Task WriteAsync(Func callback, T state) + { + throw new NotImplementedException(); + } + + public Task FlushAsync(CancellationToken cancellationToken) => _frameWriter.FlushAsync(cancellationToken); + + public Task Write100ContinueAsync(CancellationToken cancellationToken) => _frameWriter.Write100ContinueAsync(_streamId); + + public Task WriteDataAsync(ReadOnlySpan data, CancellationToken cancellationToken) + { + return _frameWriter.WriteDataAsync(_streamId, data, cancellationToken); + } + + public Task WriteStreamSuffixAsync(CancellationToken cancellationToken) + { + return _frameWriter.WriteDataAsync(_streamId, Constants.EmptyData, endStream: true, cancellationToken: cancellationToken); + } + + public void WriteResponseHeaders(int statusCode, string ReasonPhrase, HttpResponseHeaders responseHeaders) + { + _frameWriter.WriteResponseHeaders(_streamId, statusCode, responseHeaders); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2PeerSetting.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2PeerSetting.cs new file mode 100644 index 0000000000..f21b3ca929 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2PeerSetting.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public struct Http2PeerSetting + { + public Http2PeerSetting(Http2SettingsParameter parameter, uint value) + { + Parameter = parameter; + Value = value; + } + + public Http2SettingsParameter Parameter { get; } + + public uint Value { get; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2PeerSettings.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2PeerSettings.cs new file mode 100644 index 0000000000..4bf4787435 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2PeerSettings.cs @@ -0,0 +1,105 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2PeerSettings : IEnumerable + { + public const uint DefaultHeaderTableSize = 4096; + public const bool DefaultEnablePush = true; + public const uint DefaultMaxConcurrentStreams = uint.MaxValue; + public const uint DefaultInitialWindowSize = 65535; + public const uint DefaultMaxFrameSize = 16384; + public const uint DefaultMaxHeaderListSize = uint.MaxValue; + + public uint HeaderTableSize { get; set; } = DefaultHeaderTableSize; + + public bool EnablePush { get; set; } = DefaultEnablePush; + + public uint MaxConcurrentStreams { get; set; } = DefaultMaxConcurrentStreams; + + public uint InitialWindowSize { get; set; } = DefaultInitialWindowSize; + + public uint MaxFrameSize { get; set; } = DefaultMaxFrameSize; + + public uint MaxHeaderListSize { get; set; } = DefaultMaxHeaderListSize; + + public void ParseFrame(Http2Frame frame) + { + var settingsCount = frame.Length / 6; + + for (var i = 0; i < settingsCount; i++) + { + var offset = i * 6; + var id = (Http2SettingsParameter)((frame.Payload[offset] << 8) | frame.Payload[offset + 1]); + var value = (uint)((frame.Payload[offset + 2] << 24) + | (frame.Payload[offset + 3] << 16) + | (frame.Payload[offset + 4] << 8) + | frame.Payload[offset + 5]); + + switch (id) + { + case Http2SettingsParameter.SETTINGS_HEADER_TABLE_SIZE: + HeaderTableSize = value; + break; + case Http2SettingsParameter.SETTINGS_ENABLE_PUSH: + if (value != 0 && value != 1) + { + throw new Http2SettingsParameterOutOfRangeException(Http2SettingsParameter.SETTINGS_ENABLE_PUSH, + lowerBound: 0, + upperBound: 1); + } + + EnablePush = value == 1; + break; + case Http2SettingsParameter.SETTINGS_MAX_CONCURRENT_STREAMS: + MaxConcurrentStreams = value; + break; + case Http2SettingsParameter.SETTINGS_INITIAL_WINDOW_SIZE: + if (value > int.MaxValue) + { + throw new Http2SettingsParameterOutOfRangeException(Http2SettingsParameter.SETTINGS_INITIAL_WINDOW_SIZE, + lowerBound: 0, + upperBound: int.MaxValue); + } + + InitialWindowSize = value; + break; + case Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE: + if (value < Http2Frame.MinAllowedMaxFrameSize || value > Http2Frame.MaxAllowedMaxFrameSize) + { + throw new Http2SettingsParameterOutOfRangeException(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, + lowerBound: Http2Frame.MinAllowedMaxFrameSize, + upperBound: Http2Frame.MaxAllowedMaxFrameSize); + } + + MaxFrameSize = value; + break; + case Http2SettingsParameter.SETTINGS_MAX_HEADER_LIST_SIZE: + MaxHeaderListSize = value; + break; + default: + // http://httpwg.org/specs/rfc7540.html#rfc.section.6.5.2 + // + // An endpoint that receives a SETTINGS frame with any unknown or unsupported identifier MUST ignore that setting. + break; + } + } + } + + public IEnumerator GetEnumerator() + { + yield return new Http2PeerSetting(Http2SettingsParameter.SETTINGS_HEADER_TABLE_SIZE, HeaderTableSize); + yield return new Http2PeerSetting(Http2SettingsParameter.SETTINGS_ENABLE_PUSH, EnablePush ? 1u : 0); + yield return new Http2PeerSetting(Http2SettingsParameter.SETTINGS_MAX_CONCURRENT_STREAMS, MaxConcurrentStreams); + yield return new Http2PeerSetting(Http2SettingsParameter.SETTINGS_INITIAL_WINDOW_SIZE, InitialWindowSize); + yield return new Http2PeerSetting(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, MaxFrameSize); + yield return new Http2PeerSetting(Http2SettingsParameter.SETTINGS_MAX_HEADER_LIST_SIZE, MaxHeaderListSize); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2PingFrameFlags.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2PingFrameFlags.cs new file mode 100644 index 0000000000..da5163f7e7 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2PingFrameFlags.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + [Flags] + public enum Http2PingFrameFlags : byte + { + NONE = 0x0, + ACK = 0x1 + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2SettingsFrameFlags.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2SettingsFrameFlags.cs new file mode 100644 index 0000000000..5b0b8666cd --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2SettingsFrameFlags.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + [Flags] + public enum Http2SettingsFrameFlags : byte + { + NONE = 0x0, + ACK = 0x1, + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2SettingsParameter.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2SettingsParameter.cs new file mode 100644 index 0000000000..918422a4c2 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2SettingsParameter.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public enum Http2SettingsParameter : ushort + { + SETTINGS_HEADER_TABLE_SIZE = 0x1, + SETTINGS_ENABLE_PUSH = 0x2, + SETTINGS_MAX_CONCURRENT_STREAMS = 0x3, + SETTINGS_INITIAL_WINDOW_SIZE = 0x4, + SETTINGS_MAX_FRAME_SIZE = 0x5, + SETTINGS_MAX_HEADER_LIST_SIZE = 0x6, + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2SettingsParameterOutOfRangeException.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2SettingsParameterOutOfRangeException.cs new file mode 100644 index 0000000000..95db1c9d58 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2SettingsParameterOutOfRangeException.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2SettingsParameterOutOfRangeException : Exception + { + public Http2SettingsParameterOutOfRangeException(Http2SettingsParameter parameter, uint lowerBound, uint upperBound) + : base($"HTTP/2 SETTINGS parameter {parameter} must be set to a value between {lowerBound} and {upperBound}") + { + Parameter = parameter; + } + + public Http2SettingsParameter Parameter { get; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs new file mode 100644 index 0000000000..782a8ddf51 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Stream : IHttp2StreamIdFeature + { + int IHttp2StreamIdFeature.StreamId => _context.StreamId; + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs new file mode 100644 index 0000000000..0234879a37 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs @@ -0,0 +1,157 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.Extensions.Primitives; +using Microsoft.Net.Http.Headers; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Stream : HttpProtocol + { + private readonly Http2StreamContext _context; + + public Http2Stream(Http2StreamContext context) + : base(context) + { + _context = context; + + Output = new Http2OutputProducer(StreamId, _context.FrameWriter); + } + + public int StreamId => _context.StreamId; + + public bool RequestBodyStarted { get; private set; } + public bool EndStreamReceived { get; private set; } + + protected IHttp2StreamLifetimeHandler StreamLifetimeHandler => _context.StreamLifetimeHandler; + + public override bool IsUpgradableRequest => false; + + protected override void OnReset() + { + ResetIHttp2StreamIdFeature(); + } + + protected override void OnRequestProcessingEnded() + { + StreamLifetimeHandler.OnStreamCompleted(StreamId); + } + + protected override string CreateRequestId() + => StringUtilities.ConcatAsHexSuffix(ConnectionId, ':', (uint)StreamId); + + protected override MessageBody CreateMessageBody() + => Http2MessageBody.For(HttpRequestHeaders, this); + + protected override bool TryParseRequest(ReadResult result, out bool endConnection) + { + // We don't need any of the parameters because we don't implement BeginRead to actually + // do the reading from a pipeline, nor do we use endConnection to report connection-level errors. + + _httpVersion = Http.HttpVersion.Http2; + var methodText = RequestHeaders[":method"]; + Method = HttpUtilities.GetKnownMethod(methodText); + _methodText = methodText; + if (!string.Equals(RequestHeaders[":scheme"], Scheme, StringComparison.OrdinalIgnoreCase)) + { + BadHttpRequestException.Throw(RequestRejectionReason.InvalidRequestLine); + } + + var path = RequestHeaders[":path"].ToString(); + var queryIndex = path.IndexOf('?'); + + Path = queryIndex == -1 ? path : path.Substring(0, queryIndex); + QueryString = queryIndex == -1 ? string.Empty : path.Substring(queryIndex); + RawTarget = path; + + // https://tools.ietf.org/html/rfc7230#section-5.4 + // A server MUST respond with a 400 (Bad Request) status code to any + // HTTP/1.1 request message that lacks a Host header field and to any + // request message that contains more than one Host header field or a + // Host header field with an invalid field-value. + + var authority = RequestHeaders[":authority"]; + var host = HttpRequestHeaders.HeaderHost; + if (authority.Count > 0) + { + // https://tools.ietf.org/html/rfc7540#section-8.1.2.3 + // An intermediary that converts an HTTP/2 request to HTTP/1.1 MUST + // create a Host header field if one is not present in a request by + // copying the value of the ":authority" pseudo - header field. + // + // We take this one step further, we don't want mismatched :authority + // and Host headers, replace Host if :authority is defined. + HttpRequestHeaders.HeaderHost = authority; + host = authority; + } + + // TODO: OPTIONS * requests? + // To ensure that the HTTP / 1.1 request line can be reproduced + // accurately, this pseudo - header field MUST be omitted when + // translating from an HTTP/ 1.1 request that has a request target in + // origin or asterisk form(see[RFC7230], Section 5.3). + // https://tools.ietf.org/html/rfc7230#section-5.3 + + if (host.Count <= 0) + { + BadHttpRequestException.Throw(RequestRejectionReason.MissingHostHeader); + } + else if (host.Count > 1) + { + BadHttpRequestException.Throw(RequestRejectionReason.MultipleHostHeaders); + } + + var hostText = host.ToString(); + HttpUtilities.ValidateHostHeader(hostText); + + endConnection = false; + return true; + } + + public async Task OnDataAsync(ArraySegment data, bool endStream) + { + // TODO: content-length accounting + // TODO: flow-control + + try + { + if (data.Count > 0) + { + RequestBodyPipe.Writer.Write(data); + + RequestBodyStarted = true; + await RequestBodyPipe.Writer.FlushAsync(); + } + + if (endStream) + { + EndStreamReceived = true; + RequestBodyPipe.Writer.Complete(); + } + } + catch (Exception ex) + { + RequestBodyPipe.Writer.Complete(ex); + } + } + + // TODO: The HTTP/2 tests expect the request and response streams to be aborted with + // non-ConnectionAbortedExceptions. The abortReasons can include things like + // Http2ConnectionErrorException which don't derive from IOException or + // OperationCanceledException. This is probably not a good idea. + public void Http2Abort(Exception abortReason) + { + _streams?.Abort(abortReason); + + OnInputOrOutputCompleted(); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamContext.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamContext.cs new file mode 100644 index 0000000000..eea80103a7 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamContext.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.Net; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2StreamContext : IHttpProtocolContext + { + public string ConnectionId { get; set; } + public int StreamId { get; set; } + public ServiceContext ServiceContext { get; set; } + public IFeatureCollection ConnectionFeatures { get; set; } + public MemoryPool MemoryPool { get; set; } + public IPEndPoint RemoteEndPoint { get; set; } + public IPEndPoint LocalEndPoint { get; set; } + public IHttp2StreamLifetimeHandler StreamLifetimeHandler { get; set; } + public IHttp2FrameWriter FrameWriter { get; set; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamErrorException.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamErrorException.cs new file mode 100644 index 0000000000..2f63df1412 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamErrorException.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2StreamErrorException : Exception + { + public Http2StreamErrorException(int streamId, string message, Http2ErrorCode errorCode) + : base($"HTTP/2 stream ID {streamId} error ({errorCode}): {message}") + { + StreamId = streamId; + ErrorCode = errorCode; + } + + public int StreamId { get; } + + public Http2ErrorCode ErrorCode { get; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/IHttp2FrameWriter.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/IHttp2FrameWriter.cs new file mode 100644 index 0000000000..aa7b23330b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/IHttp2FrameWriter.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public interface IHttp2FrameWriter + { + void Abort(Exception error); + Task FlushAsync(CancellationToken cancellationToken = default(CancellationToken)); + Task Write100ContinueAsync(int streamId); + void WriteResponseHeaders(int streamId, int statusCode, IHeaderDictionary headers); + Task WriteDataAsync(int streamId, ReadOnlySpan data, CancellationToken cancellationToken); + Task WriteDataAsync(int streamId, ReadOnlySpan data, bool endStream, CancellationToken cancellationToken); + Task WriteRstStreamAsync(int streamId, Http2ErrorCode errorCode); + Task WriteSettingsAckAsync(); + Task WritePingAsync(Http2PingFrameFlags flags, ReadOnlySpan payload); + Task WriteGoAwayAsync(int lastStreamId, Http2ErrorCode errorCode); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/IHttp2StreamLifetimeHandler.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/IHttp2StreamLifetimeHandler.cs new file mode 100644 index 0000000000..fcb9c89637 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/IHttp2StreamLifetimeHandler.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public interface IHttp2StreamLifetimeHandler + { + void OnStreamCompleted(int streamId); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/HttpConnection.cs b/src/Servers/Kestrel/Core/src/Internal/HttpConnection.cs new file mode 100644 index 0000000000..d1a32c6a29 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/HttpConnection.cs @@ -0,0 +1,662 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.IO.Pipelines; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public class HttpConnection : ITimeoutControl, IConnectionTimeoutFeature + { + private static readonly ReadOnlyMemory Http2Id = new[] { (byte)'h', (byte)'2' }; + + private readonly HttpConnectionContext _context; + private readonly TaskCompletionSource _socketClosedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + private IList _adaptedConnections; + private IDuplexPipe _adaptedTransport; + + private readonly object _protocolSelectionLock = new object(); + private ProtocolSelectionState _protocolSelectionState = ProtocolSelectionState.Initializing; + private IRequestProcessor _requestProcessor; + private Http1Connection _http1Connection; + + private long _lastTimestamp; + private long _timeoutTimestamp = long.MaxValue; + private TimeoutAction _timeoutAction; + + private readonly object _readTimingLock = new object(); + private bool _readTimingEnabled; + private bool _readTimingPauseRequested; + private long _readTimingElapsedTicks; + private long _readTimingBytesRead; + + private readonly object _writeTimingLock = new object(); + private int _writeTimingWrites; + private long _writeTimingTimeoutTimestamp; + + private Task _lifetimeTask; + + public HttpConnection(HttpConnectionContext context) + { + _context = context; + } + + // For testing + internal HttpProtocol Http1Connection => _http1Connection; + internal IDebugger Debugger { get; set; } = DebuggerWrapper.Singleton; + + // For testing + internal bool RequestTimedOut { get; private set; } + + public string ConnectionId => _context.ConnectionId; + public IPEndPoint LocalEndPoint => _context.LocalEndPoint; + public IPEndPoint RemoteEndPoint => _context.RemoteEndPoint; + + private MemoryPool MemoryPool => _context.MemoryPool; + + // Internal for testing + internal PipeOptions AdaptedInputPipeOptions => new PipeOptions + ( + pool: MemoryPool, + readerScheduler: _context.ServiceContext.Scheduler, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: _context.ServiceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0, + resumeWriterThreshold: _context.ServiceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0, + useSynchronizationContext: false, + minimumSegmentSize: KestrelMemoryPool.MinimumSegmentSize + ); + + internal PipeOptions AdaptedOutputPipeOptions => new PipeOptions + ( + pool: MemoryPool, + readerScheduler: PipeScheduler.Inline, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: _context.ServiceContext.ServerOptions.Limits.MaxResponseBufferSize ?? 0, + resumeWriterThreshold: _context.ServiceContext.ServerOptions.Limits.MaxResponseBufferSize ?? 0, + useSynchronizationContext: false, + minimumSegmentSize: KestrelMemoryPool.MinimumSegmentSize + ); + + private IKestrelTrace Log => _context.ServiceContext.Log; + + public Task StartRequestProcessing(IHttpApplication application) + { + return _lifetimeTask = ProcessRequestsAsync(application); + } + + private async Task ProcessRequestsAsync(IHttpApplication httpApplication) + { + try + { + // TODO: When we start tracking all connection middleware for shutdown, go back + // to logging connections tart and stop in ConnectionDispatcher so we get these + // logs for all connection middleware. + Log.ConnectionStart(ConnectionId); + KestrelEventSource.Log.ConnectionStart(this); + + AdaptedPipeline adaptedPipeline = null; + var adaptedPipelineTask = Task.CompletedTask; + + // _adaptedTransport must be set prior to adding the connection to the manager in order + // to allow the connection to be aported prior to protocol selection. + _adaptedTransport = _context.Transport; + var application = _context.Application; + + + if (_context.ConnectionAdapters.Count > 0) + { + adaptedPipeline = new AdaptedPipeline(_adaptedTransport, + new Pipe(AdaptedInputPipeOptions), + new Pipe(AdaptedOutputPipeOptions), + Log); + + _adaptedTransport = adaptedPipeline; + } + + // Do this before the first await so we don't yield control to the transport until we've + // added the connection to the connection manager + _context.ServiceContext.ConnectionManager.AddConnection(_context.HttpConnectionId, this); + _lastTimestamp = _context.ServiceContext.SystemClock.UtcNow.Ticks; + + _context.ConnectionFeatures.Set(this); + + if (adaptedPipeline != null) + { + // Stream can be null here and run async will close the connection in that case + var stream = await ApplyConnectionAdaptersAsync(); + adaptedPipelineTask = adaptedPipeline.RunAsync(stream); + } + + IRequestProcessor requestProcessor = null; + + lock (_protocolSelectionLock) + { + // Ensure that the connection hasn't already been stopped. + if (_protocolSelectionState == ProtocolSelectionState.Initializing) + { + switch (SelectProtocol()) + { + case HttpProtocols.Http1: + // _http1Connection must be initialized before adding the connection to the connection manager + requestProcessor = _http1Connection = CreateHttp1Connection(_adaptedTransport, application); + _protocolSelectionState = ProtocolSelectionState.Selected; + break; + case HttpProtocols.Http2: + // _http2Connection must be initialized before yielding control to the transport thread, + // to prevent a race condition where _http2Connection.Abort() is called just as + // _http2Connection is about to be initialized. + requestProcessor = CreateHttp2Connection(_adaptedTransport, application); + _protocolSelectionState = ProtocolSelectionState.Selected; + break; + case HttpProtocols.None: + // An error was already logged in SelectProtocol(), but we should close the connection. + Abort(ex: null); + break; + default: + // SelectProtocol() only returns Http1, Http2 or None. + throw new NotSupportedException($"{nameof(SelectProtocol)} returned something other than Http1, Http2 or None."); + } + + _requestProcessor = requestProcessor; + } + } + + if (requestProcessor != null) + { + await requestProcessor.ProcessRequestsAsync(httpApplication); + } + + await adaptedPipelineTask; + await _socketClosedTcs.Task; + } + catch (Exception ex) + { + Log.LogCritical(0, ex, $"Unexpected exception in {nameof(HttpConnection)}.{nameof(ProcessRequestsAsync)}."); + } + finally + { + _context.ServiceContext.ConnectionManager.RemoveConnection(_context.HttpConnectionId); + DisposeAdaptedConnections(); + + if (_http1Connection?.IsUpgraded == true) + { + _context.ServiceContext.ConnectionManager.UpgradedConnectionCount.ReleaseOne(); + } + + Log.ConnectionStop(ConnectionId); + KestrelEventSource.Log.ConnectionStop(this); + } + } + + // For testing only + internal void Initialize(IDuplexPipe transport, IDuplexPipe application) + { + _requestProcessor = _http1Connection = CreateHttp1Connection(transport, application); + _protocolSelectionState = ProtocolSelectionState.Selected; + } + + private Http1Connection CreateHttp1Connection(IDuplexPipe transport, IDuplexPipe application) + { + return new Http1Connection(new Http1ConnectionContext + { + ConnectionId = _context.ConnectionId, + ConnectionFeatures = _context.ConnectionFeatures, + MemoryPool = MemoryPool, + LocalEndPoint = LocalEndPoint, + RemoteEndPoint = RemoteEndPoint, + ServiceContext = _context.ServiceContext, + ConnectionContext = _context.ConnectionContext, + TimeoutControl = this, + Transport = transport, + Application = application + }); + } + + private Http2Connection CreateHttp2Connection(IDuplexPipe transport, IDuplexPipe application) + { + return new Http2Connection(new Http2ConnectionContext + { + ConnectionId = _context.ConnectionId, + ServiceContext = _context.ServiceContext, + ConnectionFeatures = _context.ConnectionFeatures, + MemoryPool = MemoryPool, + LocalEndPoint = LocalEndPoint, + RemoteEndPoint = RemoteEndPoint, + Application = application, + Transport = transport + }); + } + + public void OnConnectionClosed() + { + _socketClosedTcs.TrySetResult(null); + } + + public Task StopProcessingNextRequestAsync() + { + lock (_protocolSelectionLock) + { + switch (_protocolSelectionState) + { + case ProtocolSelectionState.Initializing: + CloseUninitializedConnection(abortReason: null); + _protocolSelectionState = ProtocolSelectionState.Aborted; + break; + case ProtocolSelectionState.Selected: + _requestProcessor.StopProcessingNextRequest(); + break; + case ProtocolSelectionState.Aborted: + break; + } + } + + return _lifetimeTask; + } + + public void OnInputOrOutputCompleted() + { + lock (_protocolSelectionLock) + { + switch (_protocolSelectionState) + { + case ProtocolSelectionState.Initializing: + CloseUninitializedConnection(abortReason: null); + _protocolSelectionState = ProtocolSelectionState.Aborted; + break; + case ProtocolSelectionState.Selected: + _requestProcessor.OnInputOrOutputCompleted(); + break; + case ProtocolSelectionState.Aborted: + break; + } + + } + } + + public void Abort(ConnectionAbortedException ex) + { + lock (_protocolSelectionLock) + { + switch (_protocolSelectionState) + { + case ProtocolSelectionState.Initializing: + CloseUninitializedConnection(ex); + break; + case ProtocolSelectionState.Selected: + _requestProcessor.Abort(ex); + break; + case ProtocolSelectionState.Aborted: + break; + } + + _protocolSelectionState = ProtocolSelectionState.Aborted; + } + } + + public Task AbortAsync(ConnectionAbortedException ex) + { + Abort(ex); + + return _socketClosedTcs.Task; + } + + private async Task ApplyConnectionAdaptersAsync() + { + var connectionAdapters = _context.ConnectionAdapters; + var stream = new RawStream(_context.Transport.Input, _context.Transport.Output); + var adapterContext = new ConnectionAdapterContext(_context.ConnectionContext, stream); + _adaptedConnections = new List(connectionAdapters.Count); + + try + { + for (var i = 0; i < connectionAdapters.Count; i++) + { + var adaptedConnection = await connectionAdapters[i].OnConnectionAsync(adapterContext); + _adaptedConnections.Add(adaptedConnection); + adapterContext = new ConnectionAdapterContext(_context.ConnectionContext, adaptedConnection.ConnectionStream); + } + } + catch (Exception ex) + { + Log.LogError(0, ex, $"Uncaught exception from the {nameof(IConnectionAdapter.OnConnectionAsync)} method of an {nameof(IConnectionAdapter)}."); + + return null; + } + + return adapterContext.ConnectionStream; + } + + private void DisposeAdaptedConnections() + { + var adaptedConnections = _adaptedConnections; + if (adaptedConnections != null) + { + for (var i = adaptedConnections.Count - 1; i >= 0; i--) + { + adaptedConnections[i].Dispose(); + } + } + } + + private HttpProtocols SelectProtocol() + { + var hasTls = _context.ConnectionFeatures.Get() != null; + var applicationProtocol = _context.ConnectionFeatures.Get()?.ApplicationProtocol + ?? new ReadOnlyMemory(); + var http1Enabled = (_context.Protocols & HttpProtocols.Http1) == HttpProtocols.Http1; + var http2Enabled = (_context.Protocols & HttpProtocols.Http2) == HttpProtocols.Http2; + + string error = null; + + if (_context.Protocols == HttpProtocols.None) + { + error = CoreStrings.EndPointRequiresAtLeastOneProtocol; + } + + if (!hasTls && http1Enabled && http2Enabled) + { + error = CoreStrings.EndPointRequiresTlsForHttp1AndHttp2; + } + + if (!http1Enabled && http2Enabled && hasTls && !Http2Id.Span.SequenceEqual(applicationProtocol.Span)) + { + error = CoreStrings.EndPointHttp2NotNegotiated; + } + + if (error != null) + { + Log.LogError(0, error); + return HttpProtocols.None; + } + + return http2Enabled && (!hasTls || Http2Id.Span.SequenceEqual(applicationProtocol.Span)) ? HttpProtocols.Http2 : HttpProtocols.Http1; + } + + public void Tick(DateTimeOffset now) + { + if (_protocolSelectionState == ProtocolSelectionState.Aborted) + { + // It's safe to check for timeouts on a dead connection, + // but try not to in order to avoid extraneous logs. + return; + } + + var timestamp = now.Ticks; + + CheckForTimeout(timestamp); + + // HTTP/2 rate timeouts are not yet supported. + if (_http1Connection != null) + { + CheckForReadDataRateTimeout(timestamp); + CheckForWriteDataRateTimeout(timestamp); + } + + Interlocked.Exchange(ref _lastTimestamp, timestamp); + } + + private void CheckForTimeout(long timestamp) + { + // TODO: Use PlatformApis.VolatileRead equivalent again + if (timestamp > Interlocked.Read(ref _timeoutTimestamp)) + { + if (!Debugger.IsAttached) + { + CancelTimeout(); + + switch (_timeoutAction) + { + case TimeoutAction.StopProcessingNextRequest: + // Http/2 keep-alive timeouts are not yet supported. + _http1Connection?.StopProcessingNextRequest(); + break; + case TimeoutAction.SendTimeoutResponse: + // HTTP/2 timeout responses are not yet supported. + if (_http1Connection != null) + { + RequestTimedOut = true; + _http1Connection.SendTimeoutResponse(); + } + break; + case TimeoutAction.AbortConnection: + // This is actually supported with HTTP/2! + Abort(new ConnectionAbortedException(CoreStrings.ConnectionTimedOutByServer)); + break; + } + } + } + } + + private void CheckForReadDataRateTimeout(long timestamp) + { + Debug.Assert(_http1Connection != null); + + // The only time when both a timeout is set and the read data rate could be enforced is + // when draining the request body. Since there's already a (short) timeout set for draining, + // it's safe to not check the data rate at this point. + if (Interlocked.Read(ref _timeoutTimestamp) != long.MaxValue) + { + return; + } + + lock (_readTimingLock) + { + if (_readTimingEnabled) + { + // Reference in local var to avoid torn reads in case the min rate is changed via IHttpMinRequestBodyDataRateFeature + var minRequestBodyDataRate = _http1Connection.MinRequestBodyDataRate; + + _readTimingElapsedTicks += timestamp - _lastTimestamp; + + if (minRequestBodyDataRate?.BytesPerSecond > 0 && _readTimingElapsedTicks > minRequestBodyDataRate.GracePeriod.Ticks) + { + var elapsedSeconds = (double)_readTimingElapsedTicks / TimeSpan.TicksPerSecond; + var rate = Interlocked.Read(ref _readTimingBytesRead) / elapsedSeconds; + + if (rate < minRequestBodyDataRate.BytesPerSecond && !Debugger.IsAttached) + { + Log.RequestBodyMininumDataRateNotSatisfied(_context.ConnectionId, _http1Connection.TraceIdentifier, minRequestBodyDataRate.BytesPerSecond); + RequestTimedOut = true; + _http1Connection.SendTimeoutResponse(); + } + } + + // PauseTimingReads() cannot just set _timingReads to false. It needs to go through at least one tick + // before pausing, otherwise _readTimingElapsed might never be updated if PauseTimingReads() is always + // called before the next tick. + if (_readTimingPauseRequested) + { + _readTimingEnabled = false; + _readTimingPauseRequested = false; + } + } + } + } + + private void CheckForWriteDataRateTimeout(long timestamp) + { + Debug.Assert(_http1Connection != null); + + lock (_writeTimingLock) + { + if (_writeTimingWrites > 0 && timestamp > _writeTimingTimeoutTimestamp && !Debugger.IsAttached) + { + RequestTimedOut = true; + Log.ResponseMininumDataRateNotSatisfied(_http1Connection.ConnectionIdFeature, _http1Connection.TraceIdentifier); + Abort(new ConnectionAbortedException(CoreStrings.ConnectionTimedBecauseResponseMininumDataRateNotSatisfied)); + } + } + } + + public void SetTimeout(long ticks, TimeoutAction timeoutAction) + { + Debug.Assert(_timeoutTimestamp == long.MaxValue, "Concurrent timeouts are not supported"); + + AssignTimeout(ticks, timeoutAction); + } + + public void ResetTimeout(long ticks, TimeoutAction timeoutAction) + { + AssignTimeout(ticks, timeoutAction); + } + + public void CancelTimeout() + { + Interlocked.Exchange(ref _timeoutTimestamp, long.MaxValue); + } + + private void AssignTimeout(long ticks, TimeoutAction timeoutAction) + { + _timeoutAction = timeoutAction; + + // Add Heartbeat.Interval since this can be called right before the next heartbeat. + Interlocked.Exchange(ref _timeoutTimestamp, _lastTimestamp + ticks + Heartbeat.Interval.Ticks); + } + + public void StartTimingReads() + { + lock (_readTimingLock) + { + _readTimingElapsedTicks = 0; + _readTimingBytesRead = 0; + _readTimingEnabled = true; + } + } + + public void StopTimingReads() + { + lock (_readTimingLock) + { + _readTimingEnabled = false; + } + } + + public void PauseTimingReads() + { + lock (_readTimingLock) + { + _readTimingPauseRequested = true; + } + } + + public void ResumeTimingReads() + { + lock (_readTimingLock) + { + _readTimingEnabled = true; + + // In case pause and resume were both called between ticks + _readTimingPauseRequested = false; + } + } + + public void BytesRead(long count) + { + Interlocked.Add(ref _readTimingBytesRead, count); + } + + public void StartTimingWrite(long size) + { + Debug.Assert(_http1Connection != null); + + lock (_writeTimingLock) + { + var minResponseDataRate = _http1Connection.MinResponseDataRate; + + if (minResponseDataRate != null) + { + // Add Heartbeat.Interval since this can be called right before the next heartbeat. + var currentTimeUpperBound = _lastTimestamp + Heartbeat.Interval.Ticks; + var ticksToCompleteWriteAtMinRate = TimeSpan.FromSeconds(size / minResponseDataRate.BytesPerSecond).Ticks; + + // If ticksToCompleteWriteAtMinRate is less than the configured grace period, + // allow that write to take up to the grace period to complete. Only add the grace period + // to the current time and not to any accumulated timeout. + var singleWriteTimeoutTimestamp = currentTimeUpperBound + Math.Max( + minResponseDataRate.GracePeriod.Ticks, + ticksToCompleteWriteAtMinRate); + + // Don't penalize a connection for completing previous writes more quickly than required. + // We don't want to kill a connection when flushing the chunk terminator just because the previous + // chunk was large if the previous chunk was flushed quickly. + + // Don't add any grace period to this accumulated timeout because the grace period could + // get accumulated repeatedly making the timeout for a bunch of consecutive small writes + // far too conservative. + var accumulatedWriteTimeoutTimestamp = _writeTimingTimeoutTimestamp + ticksToCompleteWriteAtMinRate; + + _writeTimingTimeoutTimestamp = Math.Max(singleWriteTimeoutTimestamp, accumulatedWriteTimeoutTimestamp); + _writeTimingWrites++; + } + } + } + + public void StopTimingWrite() + { + lock (_writeTimingLock) + { + _writeTimingWrites--; + } + } + + void IConnectionTimeoutFeature.SetTimeout(TimeSpan timeSpan) + { + if (timeSpan < TimeSpan.Zero) + { + throw new ArgumentException(CoreStrings.PositiveFiniteTimeSpanRequired, nameof(timeSpan)); + } + if (_timeoutTimestamp != long.MaxValue) + { + throw new InvalidOperationException(CoreStrings.ConcurrentTimeoutsNotSupported); + } + + SetTimeout(timeSpan.Ticks, TimeoutAction.AbortConnection); + } + + void IConnectionTimeoutFeature.ResetTimeout(TimeSpan timeSpan) + { + if (timeSpan < TimeSpan.Zero) + { + throw new ArgumentException(CoreStrings.PositiveFiniteTimeSpanRequired, nameof(timeSpan)); + } + + ResetTimeout(timeSpan.Ticks, TimeoutAction.AbortConnection); + } + + private void CloseUninitializedConnection(ConnectionAbortedException abortReason) + { + Debug.Assert(_adaptedTransport != null); + + _context.ConnectionContext.Abort(abortReason); + + _adaptedTransport.Input.Complete(); + _adaptedTransport.Output.Complete(); + } + + private enum ProtocolSelectionState + { + Initializing, + Selected, + Aborted + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/HttpConnectionBuilderExtensions.cs b/src/Servers/Kestrel/Core/src/Internal/HttpConnectionBuilderExtensions.cs new file mode 100644 index 0000000000..dcd073855b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/HttpConnectionBuilderExtensions.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public static class HttpConnectionBuilderExtensions + { + public static IConnectionBuilder UseHttpServer(this IConnectionBuilder builder, ServiceContext serviceContext, IHttpApplication application, HttpProtocols protocols) + { + return builder.UseHttpServer(Array.Empty(), serviceContext, application, protocols); + } + + public static IConnectionBuilder UseHttpServer(this IConnectionBuilder builder, IList adapters, ServiceContext serviceContext, IHttpApplication application, HttpProtocols protocols) + { + var middleware = new HttpConnectionMiddleware(adapters, serviceContext, application, protocols); + return builder.Use(next => + { + return middleware.OnConnectionAsync; + }); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/HttpConnectionContext.cs b/src/Servers/Kestrel/Core/src/Internal/HttpConnectionContext.cs new file mode 100644 index 0000000000..161ca647a7 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/HttpConnectionContext.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Net; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public class HttpConnectionContext + { + public string ConnectionId { get; set; } + public long HttpConnectionId { get; set; } + public HttpProtocols Protocols { get; set; } + public ConnectionContext ConnectionContext { get; set; } + public ServiceContext ServiceContext { get; set; } + public IFeatureCollection ConnectionFeatures { get; set; } + public IList ConnectionAdapters { get; set; } + public MemoryPool MemoryPool { get; set; } + public IPEndPoint LocalEndPoint { get; set; } + public IPEndPoint RemoteEndPoint { get; set; } + public IDuplexPipe Transport { get; set; } + public IDuplexPipe Application { get; set; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/HttpConnectionMiddleware.cs b/src/Servers/Kestrel/Core/src/Internal/HttpConnectionMiddleware.cs new file mode 100644 index 0000000000..f7ffda4654 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/HttpConnectionMiddleware.cs @@ -0,0 +1,108 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public class HttpConnectionMiddleware + { + private static long _lastHttpConnectionId = long.MinValue; + + private readonly IList _connectionAdapters; + private readonly ServiceContext _serviceContext; + private readonly IHttpApplication _application; + private readonly HttpProtocols _protocols; + + public HttpConnectionMiddleware(IList adapters, ServiceContext serviceContext, IHttpApplication application, HttpProtocols protocols) + { + _serviceContext = serviceContext; + _application = application; + _protocols = protocols; + + // Keeping these around for now so progress can be made without updating tests + _connectionAdapters = adapters; + } + + public async Task OnConnectionAsync(ConnectionContext connectionContext) + { + // We need the transport feature so that we can cancel the output reader that the transport is using + // This is a bit of a hack but it preserves the existing semantics + var applicationFeature = connectionContext.Features.Get(); + var memoryPoolFeature = connectionContext.Features.Get(); + + var httpConnectionId = Interlocked.Increment(ref _lastHttpConnectionId); + + var httpConnectionContext = new HttpConnectionContext + { + ConnectionId = connectionContext.ConnectionId, + ConnectionContext = connectionContext, + HttpConnectionId = httpConnectionId, + Protocols = _protocols, + ServiceContext = _serviceContext, + ConnectionFeatures = connectionContext.Features, + MemoryPool = memoryPoolFeature.MemoryPool, + ConnectionAdapters = _connectionAdapters, + Transport = connectionContext.Transport, + Application = applicationFeature.Application + }; + + var connectionFeature = connectionContext.Features.Get(); + var lifetimeFeature = connectionContext.Features.Get(); + + if (connectionFeature != null) + { + if (connectionFeature.LocalIpAddress != null) + { + httpConnectionContext.LocalEndPoint = new IPEndPoint(connectionFeature.LocalIpAddress, connectionFeature.LocalPort); + } + + if (connectionFeature.RemoteIpAddress != null) + { + httpConnectionContext.RemoteEndPoint = new IPEndPoint(connectionFeature.RemoteIpAddress, connectionFeature.RemotePort); + } + } + + var connection = new HttpConnection(httpConnectionContext); + + var processingTask = connection.StartRequestProcessing(_application); + + connectionContext.Transport.Input.OnWriterCompleted( + (_, state) => ((HttpConnection)state).OnInputOrOutputCompleted(), + connection); + + connectionContext.Transport.Output.OnReaderCompleted( + (_, state) => ((HttpConnection)state).OnInputOrOutputCompleted(), + connection); + + await CancellationTokenAsTask(lifetimeFeature.ConnectionClosed); + + connection.OnConnectionClosed(); + + await processingTask; + } + + private static Task CancellationTokenAsTask(CancellationToken token) + { + if (token.IsCancellationRequested) + { + return Task.CompletedTask; + } + + // Transports already dispatch prior to tripping ConnectionClosed + // since application code can register to this token. + var tcs = new TaskCompletionSource(); + token.Register(() => tcs.SetResult(null)); + return tcs.Task; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/HttpsConnectionAdapter.cs b/src/Servers/Kestrel/Core/src/Internal/HttpsConnectionAdapter.cs new file mode 100644 index 0000000000..95ee435b43 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/HttpsConnectionAdapter.cs @@ -0,0 +1,259 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.Security; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Https.Internal +{ + public class HttpsConnectionAdapter : IConnectionAdapter + { + private static readonly ClosedAdaptedConnection _closedAdaptedConnection = new ClosedAdaptedConnection(); + + private readonly HttpsConnectionAdapterOptions _options; + private readonly X509Certificate2 _serverCertificate; + private readonly Func _serverCertificateSelector; + + private readonly ILogger _logger; + + public HttpsConnectionAdapter(HttpsConnectionAdapterOptions options) + : this(options, loggerFactory: null) + { + } + + public HttpsConnectionAdapter(HttpsConnectionAdapterOptions options, ILoggerFactory loggerFactory) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + // capture the certificate now so it can't be switched after validation + _serverCertificate = options.ServerCertificate; + _serverCertificateSelector = options.ServerCertificateSelector; + if (_serverCertificate == null && _serverCertificateSelector == null) + { + throw new ArgumentException(CoreStrings.ServerCertificateRequired, nameof(options)); + } + + // If a selector is provided then ignore the cert, it may be a default cert. + if (_serverCertificateSelector != null) + { + // SslStream doesn't allow both. + _serverCertificate = null; + } + else + { + EnsureCertificateIsAllowedForServerAuth(_serverCertificate); + } + + _options = options; + _logger = loggerFactory?.CreateLogger(nameof(HttpsConnectionAdapter)); + } + + public bool IsHttps => true; + + public Task OnConnectionAsync(ConnectionAdapterContext context) + { + // Don't trust SslStream not to block. + return Task.Run(() => InnerOnConnectionAsync(context)); + } + + private async Task InnerOnConnectionAsync(ConnectionAdapterContext context) + { + SslStream sslStream; + bool certificateRequired; + var feature = new TlsConnectionFeature(); + context.Features.Set(feature); + + if (_options.ClientCertificateMode == ClientCertificateMode.NoCertificate) + { + sslStream = new SslStream(context.ConnectionStream); + certificateRequired = false; + } + else + { + sslStream = new SslStream(context.ConnectionStream, + leaveInnerStreamOpen: false, + userCertificateValidationCallback: (sender, certificate, chain, sslPolicyErrors) => + { + if (certificate == null) + { + return _options.ClientCertificateMode != ClientCertificateMode.RequireCertificate; + } + + if (_options.ClientCertificateValidation == null) + { + if (sslPolicyErrors != SslPolicyErrors.None) + { + return false; + } + } + + var certificate2 = ConvertToX509Certificate2(certificate); + if (certificate2 == null) + { + return false; + } + + if (_options.ClientCertificateValidation != null) + { + if (!_options.ClientCertificateValidation(certificate2, chain, sslPolicyErrors)) + { + return false; + } + } + + return true; + }); + + certificateRequired = true; + } + + var timeoutFeature = context.Features.Get(); + timeoutFeature.SetTimeout(_options.HandshakeTimeout); + + try + { +#if NETCOREAPP2_1 + // Adapt to the SslStream signature + ServerCertificateSelectionCallback selector = null; + if (_serverCertificateSelector != null) + { + selector = (sender, name) => + { + context.Features.Set(sslStream); + var cert = _serverCertificateSelector(context.ConnectionContext, name); + if (cert != null) + { + EnsureCertificateIsAllowedForServerAuth(cert); + } + return cert; + }; + } + + var sslOptions = new SslServerAuthenticationOptions() + { + ServerCertificate = _serverCertificate, + ServerCertificateSelectionCallback = selector, + ClientCertificateRequired = certificateRequired, + EnabledSslProtocols = _options.SslProtocols, + CertificateRevocationCheckMode = _options.CheckCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, + ApplicationProtocols = new List() + }; + + // This is order sensitive + if ((_options.HttpProtocols & HttpProtocols.Http2) != 0) + { + sslOptions.ApplicationProtocols.Add(SslApplicationProtocol.Http2); + } + + if ((_options.HttpProtocols & HttpProtocols.Http1) != 0) + { + sslOptions.ApplicationProtocols.Add(SslApplicationProtocol.Http11); + } + + await sslStream.AuthenticateAsServerAsync(sslOptions, CancellationToken.None); +#else + var serverCert = _serverCertificate; + if (_serverCertificateSelector != null) + { + context.Features.Set(sslStream); + serverCert = _serverCertificateSelector(context.ConnectionContext, null); + if (serverCert != null) + { + EnsureCertificateIsAllowedForServerAuth(serverCert); + } + } + await sslStream.AuthenticateAsServerAsync(serverCert, certificateRequired, + _options.SslProtocols, _options.CheckCertificateRevocation); +#endif + } + catch (OperationCanceledException) + { + _logger?.LogDebug(2, CoreStrings.AuthenticationTimedOut); + sslStream.Dispose(); + return _closedAdaptedConnection; + } + catch (Exception ex) when (ex is IOException || ex is AuthenticationException) + { + _logger?.LogDebug(1, ex, CoreStrings.AuthenticationFailed); + sslStream.Dispose(); + return _closedAdaptedConnection; + } + finally + { + timeoutFeature.CancelTimeout(); + } + +#if NETCOREAPP2_1 + feature.ApplicationProtocol = sslStream.NegotiatedApplicationProtocol.Protocol; + context.Features.Set(feature); +#endif + feature.ClientCertificate = ConvertToX509Certificate2(sslStream.RemoteCertificate); + + return new HttpsAdaptedConnection(sslStream); + } + + private static void EnsureCertificateIsAllowedForServerAuth(X509Certificate2 certificate) + { + if (!CertificateLoader.IsCertificateAllowedForServerAuth(certificate)) + { + throw new InvalidOperationException(CoreStrings.FormatInvalidServerCertificateEku(certificate.Thumbprint)); + } + } + + private static X509Certificate2 ConvertToX509Certificate2(X509Certificate certificate) + { + if (certificate == null) + { + return null; + } + + if (certificate is X509Certificate2 cert2) + { + return cert2; + } + + return new X509Certificate2(certificate); + } + + private class HttpsAdaptedConnection : IAdaptedConnection + { + private readonly SslStream _sslStream; + + public HttpsAdaptedConnection(SslStream sslStream) + { + _sslStream = sslStream; + } + + public Stream ConnectionStream => _sslStream; + + public void Dispose() + { + _sslStream.Dispose(); + } + } + + private class ClosedAdaptedConnection : IAdaptedConnection + { + public Stream ConnectionStream { get; } = new ClosedStream(); + + public void Dispose() + { + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/IRequestProcessor.cs b/src/Servers/Kestrel/Core/src/Internal/IRequestProcessor.cs new file mode 100644 index 0000000000..d7385f9b96 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/IRequestProcessor.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Hosting.Server; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public interface IRequestProcessor + { + Task ProcessRequestsAsync(IHttpApplication application); + void StopProcessingNextRequest(); + void OnInputOrOutputCompleted(); + void Abort(ConnectionAbortedException ex); + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/CancellationTokenExtensions.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/CancellationTokenExtensions.cs new file mode 100644 index 0000000000..c5d0392f00 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/CancellationTokenExtensions.cs @@ -0,0 +1,76 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + internal static class CancellationTokenExtensions + { + public static IDisposable SafeRegister(this CancellationToken cancellationToken, Action callback, object state) + { + var callbackWrapper = new CancellationCallbackWrapper(callback, state); + var registration = cancellationToken.Register(s => InvokeCallback(s), callbackWrapper); + var disposeCancellationState = new DisposeCancellationState(callbackWrapper, registration); + + return new DisposableAction(s => Dispose(s), disposeCancellationState); + } + + private static void InvokeCallback(object state) + { + ((CancellationCallbackWrapper)state).TryInvoke(); + } + + private static void Dispose(object state) + { + ((DisposeCancellationState)state).TryDispose(); + } + + private class DisposeCancellationState + { + private readonly CancellationCallbackWrapper _callbackWrapper; + private readonly CancellationTokenRegistration _registration; + + public DisposeCancellationState(CancellationCallbackWrapper callbackWrapper, CancellationTokenRegistration registration) + { + _callbackWrapper = callbackWrapper; + _registration = registration; + } + + public void TryDispose() + { + if (_callbackWrapper.TrySetInvoked()) + { + _registration.Dispose(); + } + } + } + + private class CancellationCallbackWrapper + { + private readonly Action _callback; + private readonly object _state; + private int _callbackInvoked; + + public CancellationCallbackWrapper(Action callback, object state) + { + _callback = callback; + _state = state; + } + + public bool TrySetInvoked() + { + return Interlocked.Exchange(ref _callbackInvoked, 1) == 0; + } + + public void TryInvoke() + { + if (TrySetInvoked()) + { + _callback(_state); + } + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Constants.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Constants.cs new file mode 100644 index 0000000000..8aead7103c --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Constants.cs @@ -0,0 +1,43 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + internal static class Constants + { + public const int MaxExceptionDetailSize = 128; + + /// + /// The endpoint Kestrel will bind to if nothing else is specified. + /// + public static readonly string DefaultServerAddress = "http://localhost:5000"; + + /// + /// The endpoint Kestrel will bind to if nothing else is specified and a default certificate is available. + /// + public static readonly string DefaultServerHttpsAddress = "https://localhost:5001"; + + /// + /// Prefix of host name used to specify Unix sockets in the configuration. + /// + public const string UnixPipeHostPrefix = "unix:/"; + + /// + /// Prefix of host name used to specify pipe file descriptor in the configuration. + /// + public const string PipeDescriptorPrefix = "pipefd:"; + + /// + /// Prefix of host name used to specify socket descriptor in the configuration. + /// + public const string SocketDescriptorPrefix = "sockfd:"; + + public const string ServerName = "Kestrel"; + + public static readonly TimeSpan RequestBodyDrainTimeout = TimeSpan.FromSeconds(5); + + public static readonly ArraySegment EmptyData = new ArraySegment(new byte[0]); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/CorrelationIdGenerator.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/CorrelationIdGenerator.cs new file mode 100644 index 0000000000..fc161d4116 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/CorrelationIdGenerator.cs @@ -0,0 +1,48 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + internal static class CorrelationIdGenerator + { + // Base32 encoding - in ascii sort order for easy text based sorting + private static readonly string _encode32Chars = "0123456789ABCDEFGHIJKLMNOPQRSTUV"; + + // Seed the _lastConnectionId for this application instance with + // the number of 100-nanosecond intervals that have elapsed since 12:00:00 midnight, January 1, 0001 + // for a roughly increasing _lastId over restarts + private static long _lastId = DateTime.UtcNow.Ticks; + + public static string GetNextId() => GenerateId(Interlocked.Increment(ref _lastId)); + + private static unsafe string GenerateId(long id) + { + // The following routine is ~310% faster than calling long.ToString() on x64 + // and ~600% faster than calling long.ToString() on x86 in tight loops of 1 million+ iterations + // See: https://github.com/aspnet/Hosting/pull/385 + + // stackalloc to allocate array on stack rather than heap + char* charBuffer = stackalloc char[13]; + + charBuffer[0] = _encode32Chars[(int)(id >> 60) & 31]; + charBuffer[1] = _encode32Chars[(int)(id >> 55) & 31]; + charBuffer[2] = _encode32Chars[(int)(id >> 50) & 31]; + charBuffer[3] = _encode32Chars[(int)(id >> 45) & 31]; + charBuffer[4] = _encode32Chars[(int)(id >> 40) & 31]; + charBuffer[5] = _encode32Chars[(int)(id >> 35) & 31]; + charBuffer[6] = _encode32Chars[(int)(id >> 30) & 31]; + charBuffer[7] = _encode32Chars[(int)(id >> 25) & 31]; + charBuffer[8] = _encode32Chars[(int)(id >> 20) & 31]; + charBuffer[9] = _encode32Chars[(int)(id >> 15) & 31]; + charBuffer[10] = _encode32Chars[(int)(id >> 10) & 31]; + charBuffer[11] = _encode32Chars[(int)(id >> 5) & 31]; + charBuffer[12] = _encode32Chars[(int)id & 31]; + + // string ctor overload that takes char* + return new string(charBuffer, 0, 13); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/DebuggerWrapper.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/DebuggerWrapper.cs new file mode 100644 index 0000000000..df2b2644d9 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/DebuggerWrapper.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Diagnostics; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + internal sealed class DebuggerWrapper : IDebugger + { + private DebuggerWrapper() + { } + + public static IDebugger Singleton { get; } = new DebuggerWrapper(); + + public bool IsAttached => Debugger.IsAttached; + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Disposable.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Disposable.cs new file mode 100644 index 0000000000..620e749fe6 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Disposable.cs @@ -0,0 +1,43 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + /// + /// Summary description for Disposable + /// + public class Disposable : IDisposable + { + private Action _dispose; + private bool _disposedValue = false; // To detect redundant calls + + public Disposable(Action dispose) + { + _dispose = dispose; + } + + protected virtual void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + _dispose.Invoke(); + } + + _dispose = null; + _disposedValue = true; + } + } + + // This code added to correctly implement the disposable pattern. + public void Dispose() + { + // Do not change this code. Put cleanup code in Dispose(bool disposing) above. + Dispose(true); + GC.SuppressFinalize(this); + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/DisposableAction.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/DisposableAction.cs new file mode 100644 index 0000000000..ff65931e24 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/DisposableAction.cs @@ -0,0 +1,40 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + internal class DisposableAction : IDisposable + { + public static readonly DisposableAction Empty = new DisposableAction(() => { }); + + private Action _action; + private readonly object _state; + + public DisposableAction(Action action) + : this(state => ((Action)state).Invoke(), state: action) + { + } + + public DisposableAction(Action action, object state) + { + _action = action; + _state = state; + } + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + Interlocked.Exchange(ref _action, (state) => { }).Invoke(_state); + } + } + + public void Dispose() + { + Dispose(true); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Heartbeat.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Heartbeat.cs new file mode 100644 index 0000000000..fb0f17d83b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Heartbeat.cs @@ -0,0 +1,83 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public class Heartbeat : IDisposable + { + public static readonly TimeSpan Interval = TimeSpan.FromSeconds(1); + + private readonly IHeartbeatHandler[] _callbacks; + private readonly ISystemClock _systemClock; + private readonly IDebugger _debugger; + private readonly IKestrelTrace _trace; + private readonly TimeSpan _interval; + private Timer _timer; + private int _executingOnHeartbeat; + + public Heartbeat(IHeartbeatHandler[] callbacks, ISystemClock systemClock, IDebugger debugger, IKestrelTrace trace): this(callbacks, systemClock, debugger, trace, Interval) + { + + } + + internal Heartbeat(IHeartbeatHandler[] callbacks, ISystemClock systemClock, IDebugger debugger, IKestrelTrace trace, TimeSpan interval) + { + _callbacks = callbacks; + _systemClock = systemClock; + _debugger = debugger; + _trace = trace; + _interval = interval; + } + + public void Start() + { + _timer = new Timer(OnHeartbeat, state: this, dueTime: _interval, period: _interval); + } + + private static void OnHeartbeat(object state) + { + ((Heartbeat)state).OnHeartbeat(); + } + + // Called by the Timer (background) thread + internal void OnHeartbeat() + { + var now = _systemClock.UtcNow; + + if (Interlocked.Exchange(ref _executingOnHeartbeat, 1) == 0) + { + try + { + foreach (var callback in _callbacks) + { + callback.OnHeartbeat(now); + } + } + catch (Exception ex) + { + _trace.LogError(0, ex, $"{nameof(Heartbeat)}.{nameof(OnHeartbeat)}"); + } + finally + { + Interlocked.Exchange(ref _executingOnHeartbeat, 0); + } + } + else + { + if (!_debugger.IsAttached) + { + _trace.HeartbeatSlow(_interval, now); + } + } + } + + public void Dispose() + { + _timer?.Dispose(); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpConnectionManager.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpConnectionManager.cs new file mode 100644 index 0000000000..feb3a4cb29 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpConnectionManager.cs @@ -0,0 +1,72 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public class HttpConnectionManager + { + private readonly ConcurrentDictionary _connectionReferences = new ConcurrentDictionary(); + private readonly IKestrelTrace _trace; + + public HttpConnectionManager(IKestrelTrace trace, long? upgradedConnectionLimit) + : this(trace, GetCounter(upgradedConnectionLimit)) + { + } + + public HttpConnectionManager(IKestrelTrace trace, ResourceCounter upgradedConnections) + { + UpgradedConnectionCount = upgradedConnections; + _trace = trace; + } + + /// + /// Connections that have been switched to a different protocol. + /// + public ResourceCounter UpgradedConnectionCount { get; } + + public void AddConnection(long id, HttpConnection connection) + { + if (!_connectionReferences.TryAdd(id, new HttpConnectionReference(connection))) + { + throw new ArgumentException(nameof(id)); + } + } + + public void RemoveConnection(long id) + { + if (!_connectionReferences.TryRemove(id, out _)) + { + throw new ArgumentException(nameof(id)); + } + } + + public void Walk(Action callback) + { + foreach (var kvp in _connectionReferences) + { + var reference = kvp.Value; + + if (reference.TryGetConnection(out var connection)) + { + callback(connection); + } + else if (_connectionReferences.TryRemove(kvp.Key, out reference)) + { + // It's safe to modify the ConcurrentDictionary in the foreach. + // The connection reference has become unrooted because the application never completed. + _trace.ApplicationNeverCompleted(reference.ConnectionId); + } + + // If both conditions are false, the connection was removed during the heartbeat. + } + } + + private static ResourceCounter GetCounter(long? number) + => number.HasValue + ? ResourceCounter.Quota(number.Value) + : ResourceCounter.Unlimited; + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpConnectionManagerShutdownExtensions.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpConnectionManagerShutdownExtensions.cs new file mode 100644 index 0000000000..1b601c919b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpConnectionManagerShutdownExtensions.cs @@ -0,0 +1,52 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public static class HttpConnectionManagerShutdownExtensions + { + public static async Task CloseAllConnectionsAsync(this HttpConnectionManager connectionManager, CancellationToken token) + { + var closeTasks = new List(); + + connectionManager.Walk(connection => + { + closeTasks.Add(connection.StopProcessingNextRequestAsync()); + }); + + var allClosedTask = Task.WhenAll(closeTasks.ToArray()); + return await Task.WhenAny(allClosedTask, CancellationTokenAsTask(token)).ConfigureAwait(false) == allClosedTask; + } + + public static async Task AbortAllConnectionsAsync(this HttpConnectionManager connectionManager) + { + var abortTasks = new List(); + var canceledException = new ConnectionAbortedException(CoreStrings.ConnectionAbortedDuringServerShutdown); + + connectionManager.Walk(connection => + { + abortTasks.Add(connection.AbortAsync(canceledException)); + }); + + var allAbortedTask = Task.WhenAll(abortTasks.ToArray()); + return await Task.WhenAny(allAbortedTask, Task.Delay(1000)).ConfigureAwait(false) == allAbortedTask; + } + + private static Task CancellationTokenAsTask(CancellationToken token) + { + if (token.IsCancellationRequested) + { + return Task.CompletedTask; + } + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + token.Register(() => tcs.SetResult(null)); + return tcs.Task; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpConnectionReference.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpConnectionReference.cs new file mode 100644 index 0000000000..395b58a9c4 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpConnectionReference.cs @@ -0,0 +1,25 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public class HttpConnectionReference + { + private readonly WeakReference _weakReference; + + public HttpConnectionReference(HttpConnection connection) + { + _weakReference = new WeakReference(connection); + ConnectionId = connection.ConnectionId; + } + + public string ConnectionId { get; } + + public bool TryGetConnection(out HttpConnection connection) + { + return _weakReference.TryGetTarget(out connection); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpHeartbeatManager.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpHeartbeatManager.cs new file mode 100644 index 0000000000..fba8e8e1a1 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpHeartbeatManager.cs @@ -0,0 +1,31 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public class HttpHeartbeatManager : IHeartbeatHandler + { + private readonly HttpConnectionManager _connectionManager; + private readonly Action _walkCallback; + private DateTimeOffset _now; + + public HttpHeartbeatManager(HttpConnectionManager connectionManager) + { + _connectionManager = connectionManager; + _walkCallback = WalkCallback; + } + + public void OnHeartbeat(DateTimeOffset now) + { + _now = now; + _connectionManager.Walk(_walkCallback); + } + + private void WalkCallback(HttpConnection connection) + { + connection.Tick(_now); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpUtilities.Generated.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpUtilities.Generated.cs new file mode 100644 index 0000000000..15e3e1cd7b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpUtilities.Generated.cs @@ -0,0 +1,74 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Runtime.CompilerServices; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public static partial class HttpUtilities + { + // readonly primitive statics can be Jit'd to consts https://github.com/dotnet/coreclr/issues/1079 + private static readonly ulong _httpConnectMethodLong = GetAsciiStringAsLong("CONNECT "); + private static readonly ulong _httpDeleteMethodLong = GetAsciiStringAsLong("DELETE \0"); + private static readonly ulong _httpHeadMethodLong = GetAsciiStringAsLong("HEAD \0\0\0"); + private static readonly ulong _httpPatchMethodLong = GetAsciiStringAsLong("PATCH \0\0"); + private static readonly ulong _httpPostMethodLong = GetAsciiStringAsLong("POST \0\0\0"); + private static readonly ulong _httpPutMethodLong = GetAsciiStringAsLong("PUT \0\0\0\0"); + private static readonly ulong _httpOptionsMethodLong = GetAsciiStringAsLong("OPTIONS "); + private static readonly ulong _httpTraceMethodLong = GetAsciiStringAsLong("TRACE \0\0"); + + private static readonly ulong _mask8Chars = GetMaskAsLong(new byte[] + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}); + + private static readonly ulong _mask7Chars = GetMaskAsLong(new byte[] + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00}); + + private static readonly ulong _mask6Chars = GetMaskAsLong(new byte[] + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00}); + + private static readonly ulong _mask5Chars = GetMaskAsLong(new byte[] + {0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00}); + + private static readonly ulong _mask4Chars = GetMaskAsLong(new byte[] + {0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00}); + + private static readonly Tuple[] _knownMethods = + new Tuple[17]; + + private static readonly string[] _methodNames = new string[9]; + + static HttpUtilities() + { + SetKnownMethod(_mask4Chars, _httpPutMethodLong, HttpMethod.Put, 3); + SetKnownMethod(_mask5Chars, _httpHeadMethodLong, HttpMethod.Head, 4); + SetKnownMethod(_mask5Chars, _httpPostMethodLong, HttpMethod.Post, 4); + SetKnownMethod(_mask6Chars, _httpPatchMethodLong, HttpMethod.Patch, 5); + SetKnownMethod(_mask6Chars, _httpTraceMethodLong, HttpMethod.Trace, 5); + SetKnownMethod(_mask7Chars, _httpDeleteMethodLong, HttpMethod.Delete, 6); + SetKnownMethod(_mask8Chars, _httpConnectMethodLong, HttpMethod.Connect, 7); + SetKnownMethod(_mask8Chars, _httpOptionsMethodLong, HttpMethod.Options, 7); + FillKnownMethodsGaps(); + InitializeHostCharValidity(); + _methodNames[(byte)HttpMethod.Connect] = HttpMethods.Connect; + _methodNames[(byte)HttpMethod.Delete] = HttpMethods.Delete; + _methodNames[(byte)HttpMethod.Get] = HttpMethods.Get; + _methodNames[(byte)HttpMethod.Head] = HttpMethods.Head; + _methodNames[(byte)HttpMethod.Options] = HttpMethods.Options; + _methodNames[(byte)HttpMethod.Patch] = HttpMethods.Patch; + _methodNames[(byte)HttpMethod.Post] = HttpMethods.Post; + _methodNames[(byte)HttpMethod.Put] = HttpMethods.Put; + _methodNames[(byte)HttpMethod.Trace] = HttpMethods.Trace; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int GetKnownMethodIndex(ulong value) + { + const int magicNumer = 0x600000C; + var tmp = (int)value & magicNumer; + return ((tmp >> 2) | (tmp >> 23)) & 0xF; + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpUtilities.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpUtilities.cs new file mode 100644 index 0000000000..0af41644ff --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpUtilities.cs @@ -0,0 +1,541 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public static partial class HttpUtilities + { + private static readonly bool[] HostCharValidity = new bool[127]; + + public const string Http10Version = "HTTP/1.0"; + public const string Http11Version = "HTTP/1.1"; + public const string Http2Version = "HTTP/2"; + + public const string HttpUriScheme = "http://"; + public const string HttpsUriScheme = "https://"; + + // readonly primitive statics can be Jit'd to consts https://github.com/dotnet/coreclr/issues/1079 + private static readonly ulong _httpSchemeLong = GetAsciiStringAsLong(HttpUriScheme + "\0"); + private static readonly ulong _httpsSchemeLong = GetAsciiStringAsLong(HttpsUriScheme); + + private const uint _httpGetMethodInt = 542393671; // GetAsciiStringAsInt("GET "); const results in better codegen + + private const ulong _http10VersionLong = 3471766442030158920; // GetAsciiStringAsLong("HTTP/1.0"); const results in better codegen + private const ulong _http11VersionLong = 3543824036068086856; // GetAsciiStringAsLong("HTTP/1.1"); const results in better codegen + + // Only called from the static constructor + private static void InitializeHostCharValidity() + { + // Matches Http.Sys + // Matches RFC 3986 except "*" / "+" / "," / ";" / "=" and "%" HEXDIG HEXDIG which are not allowed by Http.Sys + HostCharValidity['!'] = true; + HostCharValidity['$'] = true; + HostCharValidity['&'] = true; + HostCharValidity['\''] = true; + HostCharValidity['('] = true; + HostCharValidity[')'] = true; + HostCharValidity['-'] = true; + HostCharValidity['.'] = true; + HostCharValidity['_'] = true; + HostCharValidity['~'] = true; + for (var ch = '0'; ch <= '9'; ch++) + { + HostCharValidity[ch] = true; + } + for (var ch = 'A'; ch <= 'Z'; ch++) + { + HostCharValidity[ch] = true; + } + for (var ch = 'a'; ch <= 'z'; ch++) + { + HostCharValidity[ch] = true; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void SetKnownMethod(ulong mask, ulong knownMethodUlong, HttpMethod knownMethod, int length) + { + _knownMethods[GetKnownMethodIndex(knownMethodUlong)] = new Tuple(mask, knownMethodUlong, knownMethod, length); + } + + private static void FillKnownMethodsGaps() + { + var knownMethods = _knownMethods; + var length = knownMethods.Length; + var invalidHttpMethod = new Tuple(_mask8Chars, 0ul, HttpMethod.Custom, 0); + for (int i = 0; i < length; i++) + { + if (knownMethods[i] == null) + { + knownMethods[i] = invalidHttpMethod; + } + } + } + + private static unsafe ulong GetAsciiStringAsLong(string str) + { + Debug.Assert(str.Length == 8, "String must be exactly 8 (ASCII) characters long."); + + var bytes = Encoding.ASCII.GetBytes(str); + + fixed (byte* ptr = &bytes[0]) + { + return *(ulong*)ptr; + } + } + + private static unsafe uint GetAsciiStringAsInt(string str) + { + Debug.Assert(str.Length == 4, "String must be exactly 4 (ASCII) characters long."); + + var bytes = Encoding.ASCII.GetBytes(str); + + fixed (byte* ptr = &bytes[0]) + { + return *(uint*)ptr; + } + } + + private static unsafe ulong GetMaskAsLong(byte[] bytes) + { + Debug.Assert(bytes.Length == 8, "Mask must be exactly 8 bytes long."); + + fixed (byte* ptr = bytes) + { + return *(ulong*)ptr; + } + } + + public static unsafe string GetAsciiStringNonNullCharacters(this Span span) + { + if (span.IsEmpty) + { + return string.Empty; + } + + var asciiString = new string('\0', span.Length); + + fixed (char* output = asciiString) + fixed (byte* buffer = &MemoryMarshal.GetReference(span)) + { + // This version if AsciiUtilities returns null if there are any null (0 byte) characters + // in the string + if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length)) + { + throw new InvalidOperationException(); + } + } + return asciiString; + } + + public static string GetAsciiStringEscaped(this Span span, int maxChars) + { + var sb = new StringBuilder(); + + for (var i = 0; i < Math.Min(span.Length, maxChars); i++) + { + var ch = span[i]; + sb.Append(ch < 0x20 || ch >= 0x7F ? $"\\x{ch:X2}" : ((char)ch).ToString()); + } + + if (span.Length > maxChars) + { + sb.Append("..."); + } + + return sb.ToString(); + } + + /// + /// Checks that up to 8 bytes from correspond to a known HTTP method. + /// + /// + /// A "known HTTP method" can be an HTTP method name defined in the HTTP/1.1 RFC. + /// Since all of those fit in at most 8 bytes, they can be optimally looked up by reading those bytes as a long. Once + /// in that format, it can be checked against the known method. + /// The Known Methods (CONNECT, DELETE, GET, HEAD, PATCH, POST, PUT, OPTIONS, TRACE) are all less than 8 bytes + /// and will be compared with the required space. A mask is used if the Known method is less than 8 bytes. + /// To optimize performance the GET method will be checked first. + /// + /// true if the input matches a known string, false otherwise. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe bool GetKnownMethod(this Span span, out HttpMethod method, out int length) + { + fixed (byte* data = &MemoryMarshal.GetReference(span)) + { + method = GetKnownMethod(data, span.Length, out length); + return method != HttpMethod.Custom; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static unsafe HttpMethod GetKnownMethod(byte* data, int length, out int methodLength) + { + methodLength = 0; + if (length < sizeof(uint)) + { + return HttpMethod.Custom; + } + else if (*(uint*)data == _httpGetMethodInt) + { + methodLength = 3; + return HttpMethod.Get; + } + else if (length < sizeof(ulong)) + { + return HttpMethod.Custom; + } + else + { + var value = *(ulong*)data; + var key = GetKnownMethodIndex(value); + var x = _knownMethods[key]; + + if (x != null && (value & x.Item1) == x.Item2) + { + methodLength = x.Item4; + return x.Item3; + } + } + + return HttpMethod.Custom; + } + + /// + /// Parses string for a known HTTP method. + /// + /// + /// A "known HTTP method" can be an HTTP method name defined in the HTTP/1.1 RFC. + /// The Known Methods (CONNECT, DELETE, GET, HEAD, PATCH, POST, PUT, OPTIONS, TRACE) + /// + /// + public static HttpMethod GetKnownMethod(string value) + { + // Called by http/2 + if (value == null) + { + throw new ArgumentNullException(nameof(value)); + } + + var length = value.Length; + if (length == 0) + { + throw new ArgumentException(nameof(value)); + } + + // Start with custom and assign if known method is found + var method = HttpMethod.Custom; + + var firstChar = value[0]; + if (length == 3) + { + if (firstChar == 'G' && string.Equals(value, HttpMethods.Get, StringComparison.Ordinal)) + { + method = HttpMethod.Get; + } + else if (firstChar == 'P' && string.Equals(value, HttpMethods.Put, StringComparison.Ordinal)) + { + method = HttpMethod.Put; + } + } + else if (length == 4) + { + if (firstChar == 'H' && string.Equals(value, HttpMethods.Head, StringComparison.Ordinal)) + { + method = HttpMethod.Head; + } + else if(firstChar == 'P' && string.Equals(value, HttpMethods.Post, StringComparison.Ordinal)) + { + method = HttpMethod.Post; + } + } + else if (length == 5) + { + if (firstChar == 'T' && string.Equals(value, HttpMethods.Trace, StringComparison.Ordinal)) + { + method = HttpMethod.Trace; + } + else if(firstChar == 'P' && string.Equals(value, HttpMethods.Patch, StringComparison.Ordinal)) + { + method = HttpMethod.Patch; + } + } + else if (length == 6) + { + if (firstChar == 'D' && string.Equals(value, HttpMethods.Delete, StringComparison.Ordinal)) + { + method = HttpMethod.Delete; + } + } + else if (length == 7) + { + if (firstChar == 'C' && string.Equals(value, HttpMethods.Connect, StringComparison.Ordinal)) + { + method = HttpMethod.Connect; + } + else if (firstChar == 'O' && string.Equals(value, HttpMethods.Options, StringComparison.Ordinal)) + { + method = HttpMethod.Options; + } + } + + return method; + } + + /// + /// Checks 9 bytes from correspond to a known HTTP version. + /// + /// + /// A "known HTTP version" Is is either HTTP/1.0 or HTTP/1.1. + /// Since those fit in 8 bytes, they can be optimally looked up by reading those bytes as a long. Once + /// in that format, it can be checked against the known versions. + /// The Known versions will be checked with the required '\r'. + /// To optimize performance the HTTP/1.1 will be checked first. + /// + /// true if the input matches a known string, false otherwise. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe bool GetKnownVersion(this Span span, out HttpVersion knownVersion, out byte length) + { + fixed (byte* data = &MemoryMarshal.GetReference(span)) + { + knownVersion = GetKnownVersion(data, span.Length); + if (knownVersion != HttpVersion.Unknown) + { + length = sizeof(ulong); + return true; + } + + length = 0; + return false; + } + } + + /// + /// Checks 9 bytes from correspond to a known HTTP version. + /// + /// + /// A "known HTTP version" Is is either HTTP/1.0 or HTTP/1.1. + /// Since those fit in 8 bytes, they can be optimally looked up by reading those bytes as a long. Once + /// in that format, it can be checked against the known versions. + /// The Known versions will be checked with the required '\r'. + /// To optimize performance the HTTP/1.1 will be checked first. + /// + /// true if the input matches a known string, false otherwise. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static unsafe HttpVersion GetKnownVersion(byte* location, int length) + { + HttpVersion knownVersion; + var version = *(ulong*)location; + if (length < sizeof(ulong) + 1 || location[sizeof(ulong)] != (byte)'\r') + { + knownVersion = HttpVersion.Unknown; + } + else if (version == _http11VersionLong) + { + knownVersion = HttpVersion.Http11; + } + else if (version == _http10VersionLong) + { + knownVersion = HttpVersion.Http10; + } + else + { + knownVersion = HttpVersion.Unknown; + } + + return knownVersion; + } + + /// + /// Checks 8 bytes from that correspond to 'http://' or 'https://' + /// + /// The span + /// A reference to the known scheme, if the input matches any + /// True when memory starts with known http or https schema + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe bool GetKnownHttpScheme(this Span span, out HttpScheme knownScheme) + { + fixed (byte* data = &MemoryMarshal.GetReference(span)) + { + return GetKnownHttpScheme(data, span.Length, out knownScheme); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe bool GetKnownHttpScheme(byte* location, int length, out HttpScheme knownScheme) + { + if (length >= sizeof(ulong)) + { + var scheme = *(ulong*)location; + if ((scheme & _mask7Chars) == _httpSchemeLong) + { + knownScheme = HttpScheme.Http; + return true; + } + + if (scheme == _httpsSchemeLong) + { + knownScheme = HttpScheme.Https; + return true; + } + } + knownScheme = HttpScheme.Unknown; + return false; + } + + public static string VersionToString(HttpVersion httpVersion) + { + switch (httpVersion) + { + case HttpVersion.Http10: + return Http10Version; + case HttpVersion.Http11: + return Http11Version; + default: + return null; + } + } + public static string MethodToString(HttpMethod method) + { + int methodIndex = (int)method; + if (methodIndex >= 0 && methodIndex <= 8) + { + return _methodNames[methodIndex]; + } + return null; + } + + public static string SchemeToString(HttpScheme scheme) + { + switch (scheme) + { + case HttpScheme.Http: + return HttpUriScheme; + case HttpScheme.Https: + return HttpsUriScheme; + default: + return null; + } + } + + public static void ValidateHostHeader(string hostText) + { + if (string.IsNullOrEmpty(hostText)) + { + // The spec allows empty values + return; + } + + var firstChar = hostText[0]; + if (firstChar == '[') + { + // Tail call + ValidateIPv6Host(hostText); + } + else + { + if (firstChar == ':') + { + // Only a port + BadHttpRequestException.Throw(RequestRejectionReason.InvalidHostHeader, hostText); + } + + // Enregister array + var hostCharValidity = HostCharValidity; + for (var i = 0; i < hostText.Length; i++) + { + if (!hostCharValidity[hostText[i]]) + { + // Tail call + ValidateHostPort(hostText, i); + return; + } + } + } + } + + // The lead '[' was already checked + private static void ValidateIPv6Host(string hostText) + { + for (var i = 1; i < hostText.Length; i++) + { + var ch = hostText[i]; + if (ch == ']') + { + // [::1] is the shortest valid IPv6 host + if (i < 4) + { + BadHttpRequestException.Throw(RequestRejectionReason.InvalidHostHeader, hostText); + } + else if (i + 1 < hostText.Length) + { + // Tail call + ValidateHostPort(hostText, i + 1); + } + return; + } + + if (!IsHex(ch) && ch != ':' && ch != '.') + { + BadHttpRequestException.Throw(RequestRejectionReason.InvalidHostHeader, hostText); + } + } + + // Must contain a ']' + BadHttpRequestException.Throw(RequestRejectionReason.InvalidHostHeader, hostText); + } + + private static void ValidateHostPort(string hostText, int offset) + { + var firstChar = hostText[offset]; + offset++; + if (firstChar != ':' || offset == hostText.Length) + { + // Must have at least one number after the colon if present. + BadHttpRequestException.Throw(RequestRejectionReason.InvalidHostHeader, hostText); + } + + for (var i = offset; i < hostText.Length; i++) + { + if (!IsNumeric(hostText[i])) + { + BadHttpRequestException.Throw(RequestRejectionReason.InvalidHostHeader, hostText); + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool IsNumeric(char ch) + { + // '0' <= ch && ch <= '9' + // (uint)(ch - '0') <= (uint)('9' - '0') + + // Subtract start of range '0' + // Cast to uint to change negative numbers to large numbers + // Check if less than 10 representing chars '0' - '9' + return (uint)(ch - '0') < 10u; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool IsHex(char ch) + { + return IsNumeric(ch) + // || ('a' <= ch && ch <= 'f') + // || ('A' <= ch && ch <= 'F'); + + // Lowercase indiscriminately (or with 32) + // Subtract start of range 'a' + // Cast to uint to change negative numbers to large numbers + // Check if less than 6 representing chars 'a' - 'f' + || (uint)((ch | 32) - 'a') < 6u; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IDebugger.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IDebugger.cs new file mode 100644 index 0000000000..cb1448fd4f --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IDebugger.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public interface IDebugger + { + bool IsAttached { get; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IHeartbeatHandler.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IHeartbeatHandler.cs new file mode 100644 index 0000000000..e6a355e829 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IHeartbeatHandler.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public interface IHeartbeatHandler + { + void OnHeartbeat(DateTimeOffset now); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IKestrelTrace.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IKestrelTrace.cs new file mode 100644 index 0000000000..caaf66cfae --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IKestrelTrace.cs @@ -0,0 +1,63 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public interface IKestrelTrace : ILogger + { + void ConnectionStart(string connectionId); + + void ConnectionStop(string connectionId); + + void ConnectionPause(string connectionId); + + void ConnectionResume(string connectionId); + + void ConnectionRejected(string connectionId); + + void ConnectionKeepAlive(string connectionId); + + void ConnectionDisconnect(string connectionId); + + void RequestProcessingError(string connectionId, Exception ex); + + void ConnectionHeadResponseBodyWrite(string connectionId, long count); + + void NotAllConnectionsClosedGracefully(); + + void ConnectionBadRequest(string connectionId, BadHttpRequestException ex); + + void ApplicationError(string connectionId, string traceIdentifier, Exception ex); + + void NotAllConnectionsAborted(); + + void HeartbeatSlow(TimeSpan interval, DateTimeOffset now); + + void ApplicationNeverCompleted(string connectionId); + + void RequestBodyStart(string connectionId, string traceIdentifier); + + void RequestBodyDone(string connectionId, string traceIdentifier); + + void RequestBodyNotEntirelyRead(string connectionId, string traceIdentifier); + + void RequestBodyDrainTimedOut(string connectionId, string traceIdentifier); + + void RequestBodyMininumDataRateNotSatisfied(string connectionId, string traceIdentifier, double rate); + + void ResponseMininumDataRateNotSatisfied(string connectionId, string traceIdentifier); + + void ApplicationAbortedConnection(string connectionId, string traceIdentifier); + + void Http2ConnectionError(string connectionId, Http2ConnectionErrorException ex); + + void Http2StreamError(string connectionId, Http2StreamErrorException ex); + + void HPackDecodingError(string connectionId, int streamId, HPackDecodingException ex); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ISystemClock.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ISystemClock.cs new file mode 100644 index 0000000000..ddc5b1fd66 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ISystemClock.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + /// + /// Abstracts the system clock to facilitate testing. + /// + public interface ISystemClock + { + /// + /// Retrieves the current system time in UTC. + /// + DateTimeOffset UtcNow { get; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ITimeoutControl.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ITimeoutControl.cs new file mode 100644 index 0000000000..69175ec0b5 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ITimeoutControl.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public interface ITimeoutControl + { + void SetTimeout(long ticks, TimeoutAction timeoutAction); + void ResetTimeout(long ticks, TimeoutAction timeoutAction); + void CancelTimeout(); + + void StartTimingReads(); + void PauseTimingReads(); + void ResumeTimingReads(); + void StopTimingReads(); + void BytesRead(long count); + + void StartTimingWrite(long size); + void StopTimingWrite(); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelEventSource.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelEventSource.cs new file mode 100644 index 0000000000..7faf486d1f --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelEventSource.cs @@ -0,0 +1,114 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Diagnostics.Tracing; +using System.Runtime.CompilerServices; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + [EventSource(Name = "Microsoft-AspNetCore-Server-Kestrel")] + public sealed class KestrelEventSource : EventSource + { + public static readonly KestrelEventSource Log = new KestrelEventSource(); + + private KestrelEventSource() + { + } + + // NOTE + // - The 'Start' and 'Stop' suffixes on the following event names have special meaning in EventSource. They + // enable creating 'activities'. + // For more information, take a look at the following blog post: + // https://blogs.msdn.microsoft.com/vancem/2015/09/14/exploring-eventsource-activity-correlation-and-causation-features/ + // - A stop event's event id must be next one after its start event. + // - Avoid renaming methods or parameters marked with EventAttribute. EventSource uses these to form the event object. + + [NonEvent] + public void ConnectionStart(HttpConnection connection) + { + // avoid allocating strings unless this event source is enabled + if (IsEnabled()) + { + ConnectionStart( + connection.ConnectionId, + connection.LocalEndPoint?.ToString(), + connection.RemoteEndPoint?.ToString()); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + [Event(1, Level = EventLevel.Verbose)] + private void ConnectionStart(string connectionId, + string localEndPoint, + string remoteEndPoint) + { + WriteEvent( + 1, + connectionId, + localEndPoint, + remoteEndPoint + ); + } + + [NonEvent] + public void ConnectionStop(HttpConnection connection) + { + if (IsEnabled()) + { + ConnectionStop(connection.ConnectionId); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + [Event(2, Level = EventLevel.Verbose)] + private void ConnectionStop(string connectionId) + { + WriteEvent(2, connectionId); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + [Event(5, Level = EventLevel.Verbose)] + public void ConnectionRejected(string connectionId) + { + if (IsEnabled()) + { + WriteEvent(5, connectionId); + } + } + + [NonEvent] + public void RequestStart(HttpProtocol httpProtocol) + { + // avoid allocating the trace identifier unless logging is enabled + if (IsEnabled()) + { + RequestStart(httpProtocol.ConnectionIdFeature, httpProtocol.TraceIdentifier); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + [Event(3, Level = EventLevel.Verbose)] + private void RequestStart(string connectionId, string requestId) + { + WriteEvent(3, connectionId, requestId); + } + + [NonEvent] + public void RequestStop(HttpProtocol httpProtocol) + { + // avoid allocating the trace identifier unless logging is enabled + if (IsEnabled()) + { + RequestStop(httpProtocol.ConnectionIdFeature, httpProtocol.TraceIdentifier); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + [Event(4, Level = EventLevel.Verbose)] + private void RequestStop(string connectionId, string requestId) + { + WriteEvent(4, connectionId, requestId); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelTrace.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelTrace.cs new file mode 100644 index 0000000000..3ef1aa1721 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelTrace.cs @@ -0,0 +1,228 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public class KestrelTrace : IKestrelTrace + { + private static readonly Action _connectionStart = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, nameof(ConnectionStart)), @"Connection id ""{ConnectionId}"" started."); + + private static readonly Action _connectionStop = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, nameof(ConnectionStop)), @"Connection id ""{ConnectionId}"" stopped."); + + private static readonly Action _connectionPause = + LoggerMessage.Define(LogLevel.Debug, new EventId(4, nameof(ConnectionPause)), @"Connection id ""{ConnectionId}"" paused."); + + private static readonly Action _connectionResume = + LoggerMessage.Define(LogLevel.Debug, new EventId(5, nameof(ConnectionResume)), @"Connection id ""{ConnectionId}"" resumed."); + + private static readonly Action _connectionKeepAlive = + LoggerMessage.Define(LogLevel.Debug, new EventId(9, nameof(ConnectionKeepAlive)), @"Connection id ""{ConnectionId}"" completed keep alive response."); + + private static readonly Action _connectionDisconnect = + LoggerMessage.Define(LogLevel.Debug, new EventId(10, nameof(ConnectionDisconnect)), @"Connection id ""{ConnectionId}"" disconnecting."); + + private static readonly Action _applicationError = + LoggerMessage.Define(LogLevel.Error, new EventId(13, nameof(ApplicationError)), @"Connection id ""{ConnectionId}"", Request id ""{TraceIdentifier}"": An unhandled exception was thrown by the application."); + + private static readonly Action _notAllConnectionsClosedGracefully = + LoggerMessage.Define(LogLevel.Debug, new EventId(16, nameof(NotAllConnectionsClosedGracefully)), "Some connections failed to close gracefully during server shutdown."); + + private static readonly Action _connectionBadRequest = + LoggerMessage.Define(LogLevel.Information, new EventId(17, nameof(ConnectionBadRequest)), @"Connection id ""{ConnectionId}"" bad request data: ""{message}"""); + + private static readonly Action _connectionHeadResponseBodyWrite = + LoggerMessage.Define(LogLevel.Debug, new EventId(18, nameof(ConnectionHeadResponseBodyWrite)), @"Connection id ""{ConnectionId}"" write of ""{count}"" body bytes to non-body HEAD response."); + + private static readonly Action _requestProcessingError = + LoggerMessage.Define(LogLevel.Information, new EventId(20, nameof(RequestProcessingError)), @"Connection id ""{ConnectionId}"" request processing ended abnormally."); + + private static readonly Action _notAllConnectionsAborted = + LoggerMessage.Define(LogLevel.Debug, new EventId(21, nameof(NotAllConnectionsAborted)), "Some connections failed to abort during server shutdown."); + + private static readonly Action _heartbeatSlow = + LoggerMessage.Define(LogLevel.Warning, new EventId(22, nameof(HeartbeatSlow)), @"Heartbeat took longer than ""{interval}"" at ""{now}""."); + + private static readonly Action _applicationNeverCompleted = + LoggerMessage.Define(LogLevel.Critical, new EventId(23, nameof(ApplicationNeverCompleted)), @"Connection id ""{ConnectionId}"" application never completed"); + + private static readonly Action _connectionRejected = + LoggerMessage.Define(LogLevel.Warning, new EventId(24, nameof(ConnectionRejected)), @"Connection id ""{ConnectionId}"" rejected because the maximum number of concurrent connections has been reached."); + + private static readonly Action _requestBodyStart = + LoggerMessage.Define(LogLevel.Debug, new EventId(25, nameof(RequestBodyStart)), @"Connection id ""{ConnectionId}"", Request id ""{TraceIdentifier}"": started reading request body."); + + private static readonly Action _requestBodyDone = + LoggerMessage.Define(LogLevel.Debug, new EventId(26, nameof(RequestBodyDone)), @"Connection id ""{ConnectionId}"", Request id ""{TraceIdentifier}"": done reading request body."); + + private static readonly Action _requestBodyMinimumDataRateNotSatisfied = + LoggerMessage.Define(LogLevel.Information, new EventId(27, nameof(RequestBodyMininumDataRateNotSatisfied)), @"Connection id ""{ConnectionId}"", Request id ""{TraceIdentifier}"": the request timed out because it was not sent by the client at a minimum of {Rate} bytes/second."); + + private static readonly Action _responseMinimumDataRateNotSatisfied = + LoggerMessage.Define(LogLevel.Information, new EventId(28, nameof(ResponseMininumDataRateNotSatisfied)), @"Connection id ""{ConnectionId}"", Request id ""{TraceIdentifier}"": the connection was closed because the response was not read by the client at the specified minimum data rate."); + + private static readonly Action _http2ConnectionError = + LoggerMessage.Define(LogLevel.Information, new EventId(29, nameof(Http2ConnectionError)), @"Connection id ""{ConnectionId}"": HTTP/2 connection error."); + + private static readonly Action _http2StreamError = + LoggerMessage.Define(LogLevel.Information, new EventId(30, nameof(Http2StreamError)), @"Connection id ""{ConnectionId}"": HTTP/2 stream error."); + + private static readonly Action _hpackDecodingError = + LoggerMessage.Define(LogLevel.Information, new EventId(31, nameof(HPackDecodingError)), @"Connection id ""{ConnectionId}"": HPACK decoding error while decoding headers for stream ID {StreamId}."); + + private static readonly Action _requestBodyNotEntirelyRead = + LoggerMessage.Define(LogLevel.Information, new EventId(32, nameof(RequestBodyNotEntirelyRead)), @"Connection id ""{ConnectionId}"", Request id ""{TraceIdentifier}"": the application completed without reading the entire request body."); + + private static readonly Action _requestBodyDrainTimedOut = + LoggerMessage.Define(LogLevel.Information, new EventId(33, nameof(RequestBodyDrainTimedOut)), @"Connection id ""{ConnectionId}"", Request id ""{TraceIdentifier}"": automatic draining of the request body timed out after taking over 5 seconds."); + + private static readonly Action _applicationAbortedConnection = + LoggerMessage.Define(LogLevel.Information, new EventId(34, nameof(RequestBodyDrainTimedOut)), @"Connection id ""{ConnectionId}"", Request id ""{TraceIdentifier}"": the application aborted the connection."); + + protected readonly ILogger _logger; + + public KestrelTrace(ILogger logger) + { + _logger = logger; + } + + public virtual void ConnectionStart(string connectionId) + { + _connectionStart(_logger, connectionId, null); + } + + public virtual void ConnectionStop(string connectionId) + { + _connectionStop(_logger, connectionId, null); + } + + public virtual void ConnectionPause(string connectionId) + { + _connectionPause(_logger, connectionId, null); + } + + public virtual void ConnectionResume(string connectionId) + { + _connectionResume(_logger, connectionId, null); + } + + public virtual void ConnectionKeepAlive(string connectionId) + { + _connectionKeepAlive(_logger, connectionId, null); + } + + public virtual void ConnectionRejected(string connectionId) + { + _connectionRejected(_logger, connectionId, null); + } + + public virtual void ConnectionDisconnect(string connectionId) + { + _connectionDisconnect(_logger, connectionId, null); + } + + public virtual void ApplicationError(string connectionId, string traceIdentifier, Exception ex) + { + _applicationError(_logger, connectionId, traceIdentifier, ex); + } + + public virtual void ConnectionHeadResponseBodyWrite(string connectionId, long count) + { + _connectionHeadResponseBodyWrite(_logger, connectionId, count, null); + } + + public virtual void NotAllConnectionsClosedGracefully() + { + _notAllConnectionsClosedGracefully(_logger, null); + } + + public virtual void ConnectionBadRequest(string connectionId, BadHttpRequestException ex) + { + _connectionBadRequest(_logger, connectionId, ex.Message, ex); + } + + public virtual void RequestProcessingError(string connectionId, Exception ex) + { + _requestProcessingError(_logger, connectionId, ex); + } + + public virtual void NotAllConnectionsAborted() + { + _notAllConnectionsAborted(_logger, null); + } + + public virtual void HeartbeatSlow(TimeSpan interval, DateTimeOffset now) + { + _heartbeatSlow(_logger, interval, now, null); + } + + public virtual void ApplicationNeverCompleted(string connectionId) + { + _applicationNeverCompleted(_logger, connectionId, null); + } + + public virtual void RequestBodyStart(string connectionId, string traceIdentifier) + { + _requestBodyStart(_logger, connectionId, traceIdentifier, null); + } + + public virtual void RequestBodyDone(string connectionId, string traceIdentifier) + { + _requestBodyDone(_logger, connectionId, traceIdentifier, null); + } + + public virtual void RequestBodyMininumDataRateNotSatisfied(string connectionId, string traceIdentifier, double rate) + { + _requestBodyMinimumDataRateNotSatisfied(_logger, connectionId, traceIdentifier, rate, null); + } + + public virtual void RequestBodyNotEntirelyRead(string connectionId, string traceIdentifier) + { + _requestBodyNotEntirelyRead(_logger, connectionId, traceIdentifier, null); + } + + public virtual void RequestBodyDrainTimedOut(string connectionId, string traceIdentifier) + { + _requestBodyDrainTimedOut(_logger, connectionId, traceIdentifier, null); + } + + public virtual void ResponseMininumDataRateNotSatisfied(string connectionId, string traceIdentifier) + { + _responseMinimumDataRateNotSatisfied(_logger, connectionId, traceIdentifier, null); + } + + public virtual void ApplicationAbortedConnection(string connectionId, string traceIdentifier) + { + _applicationAbortedConnection(_logger, connectionId, traceIdentifier, null); + } + + public virtual void Http2ConnectionError(string connectionId, Http2ConnectionErrorException ex) + { + _http2ConnectionError(_logger, connectionId, ex); + } + + public virtual void Http2StreamError(string connectionId, Http2StreamErrorException ex) + { + _http2StreamError(_logger, connectionId, ex); + } + + public virtual void HPackDecodingError(string connectionId, int streamId, HPackDecodingException ex) + { + _hpackDecodingError(_logger, connectionId, streamId, ex); + } + + public virtual void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) + => _logger.Log(logLevel, eventId, state, exception, formatter); + + public virtual bool IsEnabled(LogLevel logLevel) => _logger.IsEnabled(logLevel); + + public virtual IDisposable BeginScope(TState state) => _logger.BeginScope(state); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ReadOnlyStream.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ReadOnlyStream.cs new file mode 100644 index 0000000000..cf4b02a41f --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ReadOnlyStream.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public abstract class ReadOnlyStream : Stream + { + public override bool CanRead => true; + + public override bool CanWrite => false; + + public override int WriteTimeout + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + => throw new NotSupportedException(); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => throw new NotSupportedException(); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ResourceCounter.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ResourceCounter.cs new file mode 100644 index 0000000000..68c96e25ac --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ResourceCounter.cs @@ -0,0 +1,77 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public abstract class ResourceCounter + { + public abstract bool TryLockOne(); + public abstract void ReleaseOne(); + + public static ResourceCounter Unlimited { get; } = new UnlimitedCounter(); + public static ResourceCounter Quota(long amount) => new FiniteCounter(amount); + + private class UnlimitedCounter : ResourceCounter + { + public override bool TryLockOne() => true; + public override void ReleaseOne() + { + } + } + + internal class FiniteCounter : ResourceCounter + { + private readonly long _max; + private long _count; + + public FiniteCounter(long max) + { + if (max < 0) + { + throw new ArgumentOutOfRangeException(CoreStrings.NonNegativeNumberRequired); + } + + _max = max; + } + + public override bool TryLockOne() + { + var count = _count; + + // Exit if count == MaxValue as incrementing would overflow. + + while (count < _max && count != long.MaxValue) + { + var prev = Interlocked.CompareExchange(ref _count, count + 1, count); + if (prev == count) + { + return true; + } + + // Another thread changed the count before us. Try again with the new counter value. + count = prev; + } + + return false; + } + + public override void ReleaseOne() + { + Interlocked.Decrement(ref _count); + + Debug.Assert(_count >= 0, "Resource count is negative. More resources were released than were locked."); + } + + // for testing + internal long Count + { + get => _count; + set => _count = value; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/StackTraceHiddenAttribute.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/StackTraceHiddenAttribute.cs new file mode 100644 index 0000000000..cf6e608d7b --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/StackTraceHiddenAttribute.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Diagnostics +{ + /// + /// Attribute to add to non-returning throw only methods, + /// to restore the stack trace back to what it would be if the throw was in-place + /// + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method | AttributeTargets.Constructor | AttributeTargets.Struct, Inherited = false)] + internal sealed class StackTraceHiddenAttribute : Attribute + { + // https://github.com/dotnet/coreclr/blob/eb54e48b13fdfb7233b7bcd32b93792ba3e89f0c/src/mscorlib/shared/System/Diagnostics/StackTraceHiddenAttribute.cs + public StackTraceHiddenAttribute() { } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Streams.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Streams.cs new file mode 100644 index 0000000000..a439928e5e --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/Streams.cs @@ -0,0 +1,71 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public class Streams + { + private static readonly ThrowingWriteOnlyStream _throwingResponseStream + = new ThrowingWriteOnlyStream(new InvalidOperationException(CoreStrings.ResponseStreamWasUpgraded)); + private readonly HttpResponseStream _response; + private readonly HttpRequestStream _request; + private readonly WrappingStream _upgradeableResponse; + private readonly HttpRequestStream _emptyRequest; + private readonly Stream _upgradeStream; + + public Streams(IHttpBodyControlFeature bodyControl, IHttpResponseControl httpResponseControl) + { + _request = new HttpRequestStream(bodyControl); + _emptyRequest = new HttpRequestStream(bodyControl); + _response = new HttpResponseStream(bodyControl, httpResponseControl); + _upgradeableResponse = new WrappingStream(_response); + _upgradeStream = new HttpUpgradeStream(_request, _response); + } + + public Stream Upgrade() + { + // causes writes to context.Response.Body to throw + _upgradeableResponse.SetInnerStream(_throwingResponseStream); + // _upgradeStream always uses _response + return _upgradeStream; + } + + public (Stream request, Stream response) Start(MessageBody body) + { + _request.StartAcceptingReads(body); + _emptyRequest.StartAcceptingReads(MessageBody.ZeroContentLengthClose); + _response.StartAcceptingWrites(); + + if (body.RequestUpgrade) + { + // until Upgrade() is called, context.Response.Body should use the normal output stream + _upgradeableResponse.SetInnerStream(_response); + // upgradeable requests should never have a request body + return (_emptyRequest, _upgradeableResponse); + } + else + { + return (_request, _response); + } + } + + public void Stop() + { + _request.StopAcceptingReads(); + _emptyRequest.StopAcceptingReads(); + _response.StopAcceptingWrites(); + } + + public void Abort(Exception error) + { + _request.Abort(error); + _emptyRequest.Abort(error); + _response.Abort(); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/StringUtilities.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/StringUtilities.cs new file mode 100644 index 0000000000..95073a97de --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/StringUtilities.cs @@ -0,0 +1,190 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Numerics; +using System.Runtime.CompilerServices; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + internal class StringUtilities + { + public static unsafe bool TryGetAsciiString(byte* input, char* output, int count) + { + // Calculate end position + var end = input + count; + // Start as valid + var isValid = true; + + do + { + // If Vector not-accelerated or remaining less than vector size + if (!Vector.IsHardwareAccelerated || input > end - Vector.Count) + { + if (IntPtr.Size == 8) // Use Intrinsic switch for branch elimination + { + // 64-bit: Loop longs by default + while (input <= end - sizeof(long)) + { + isValid &= CheckBytesInAsciiRange(((long*)input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + output[4] = (char)input[4]; + output[5] = (char)input[5]; + output[6] = (char)input[6]; + output[7] = (char)input[7]; + + input += sizeof(long); + output += sizeof(long); + } + if (input <= end - sizeof(int)) + { + isValid &= CheckBytesInAsciiRange(((int*)input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + + input += sizeof(int); + output += sizeof(int); + } + } + else + { + // 32-bit: Loop ints by default + while (input <= end - sizeof(int)) + { + isValid &= CheckBytesInAsciiRange(((int*)input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + + input += sizeof(int); + output += sizeof(int); + } + } + if (input <= end - sizeof(short)) + { + isValid &= CheckBytesInAsciiRange(((short*)input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + + input += sizeof(short); + output += sizeof(short); + } + if (input < end) + { + isValid &= CheckBytesInAsciiRange(((sbyte*)input)[0]); + output[0] = (char)input[0]; + } + + return isValid; + } + + // do/while as entry condition already checked + do + { + var vector = Unsafe.AsRef>(input); + isValid &= CheckBytesInAsciiRange(vector); + Vector.Widen( + vector, + out Unsafe.AsRef>(output), + out Unsafe.AsRef>(output + Vector.Count)); + + input += Vector.Count; + output += Vector.Count; + } while (input <= end - Vector.Count); + + // Vector path done, loop back to do non-Vector + // If is a exact multiple of vector size, bail now + } while (input < end); + + return isValid; + } + + private static readonly string _encode16Chars = "0123456789ABCDEF"; + + /// + /// A faster version of String.Concat(, , .ToString("X8")) + /// + /// + /// + /// + /// + public static unsafe string ConcatAsHexSuffix(string str, char separator, uint number) + { + var length = 1 + 8; + if (str != null) + { + length += str.Length; + } + + // stackalloc to allocate array on stack rather than heap + char* charBuffer = stackalloc char[length]; + + var i = 0; + if (str != null) + { + for (i = 0; i < str.Length; i++) + { + charBuffer[i] = str[i]; + } + } + + charBuffer[i] = separator; + + charBuffer[i + 1] = _encode16Chars[(int)(number >> 28) & 0xF]; + charBuffer[i + 2] = _encode16Chars[(int)(number >> 24) & 0xF]; + charBuffer[i + 3] = _encode16Chars[(int)(number >> 20) & 0xF]; + charBuffer[i + 4] = _encode16Chars[(int)(number >> 16) & 0xF]; + charBuffer[i + 5] = _encode16Chars[(int)(number >> 12) & 0xF]; + charBuffer[i + 6] = _encode16Chars[(int)(number >> 8) & 0xF]; + charBuffer[i + 7] = _encode16Chars[(int)(number >> 4) & 0xF]; + charBuffer[i + 8] = _encode16Chars[(int)number & 0xF]; + + // string ctor overload that takes char* + return new string(charBuffer, 0, length); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push + private static bool CheckBytesInAsciiRange(Vector check) + { + // Vectorized byte range check, signed byte > 0 for 1-127 + return Vector.GreaterThanAll(check, Vector.Zero); + } + + // Validate: bytes != 0 && bytes <= 127 + // Subtract 1 from all bytes to move 0 to high bits + // bitwise or with self to catch all > 127 bytes + // mask off high bits and check if 0 + + [MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push + private static bool CheckBytesInAsciiRange(long check) + { + const long HighBits = unchecked((long)0x8080808080808080L); + return (((check - 0x0101010101010101L) | check) & HighBits) == 0; + } + + private static bool CheckBytesInAsciiRange(int check) + { + const int HighBits = unchecked((int)0x80808080); + return (((check - 0x01010101) | check) & HighBits) == 0; + } + + private static bool CheckBytesInAsciiRange(short check) + { + const short HighBits = unchecked((short)0x8080); + return (((short)(check - 0x0101) | check) & HighBits) == 0; + } + + private static bool CheckBytesInAsciiRange(sbyte check) + => check > 0; + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/SystemClock.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/SystemClock.cs new file mode 100644 index 0000000000..1284ef9f4f --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/SystemClock.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + /// + /// Provides access to the normal system clock. + /// + internal class SystemClock : ISystemClock + { + /// + /// Retrieves the current system time in UTC. + /// + public DateTimeOffset UtcNow + { + get + { + return DateTimeOffset.UtcNow; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ThrowingWriteOnlyStream.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ThrowingWriteOnlyStream.cs new file mode 100644 index 0000000000..aff2b7a6d4 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ThrowingWriteOnlyStream.cs @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public class ThrowingWriteOnlyStream : WriteOnlyStream + { + private readonly Exception _exception; + + public ThrowingWriteOnlyStream(Exception exception) + { + _exception = exception; + } + + public override bool CanSeek => false; + + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + => throw _exception; + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => throw _exception; + + public override void Flush() + => throw _exception; + + public override long Seek(long offset, SeekOrigin origin) + => throw new NotSupportedException(); + + public override void SetLength(long value) + => throw new NotSupportedException(); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TimeoutAction.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TimeoutAction.cs new file mode 100644 index 0000000000..b2aa9df8cf --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TimeoutAction.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public enum TimeoutAction + { + StopProcessingNextRequest, + SendTimeoutResponse, + AbortConnection, + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/UriUtilities.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/UriUtilities.cs new file mode 100644 index 0000000000..272b30cabf --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/UriUtilities.cs @@ -0,0 +1,35 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public class UriUtilities + { + /// + /// Returns true if character is valid in the 'authority' section of a URI. + /// + /// + /// The character + /// + public static bool IsValidAuthorityCharacter(byte ch) + { + // Examples: + // microsoft.com + // hostname:8080 + // [::]:8080 + // [fe80::] + // 127.0.0.1 + // user@host.com + // user:password@host.com + return + (ch >= '0' && ch <= '9') || + (ch >= 'A' && ch <= 'Z') || + (ch >= 'a' && ch <= 'z') || + ch == ':' || + ch == '.' || + ch == '[' || + ch == ']' || + ch == '@'; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/WrappingStream.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/WrappingStream.cs new file mode 100644 index 0000000000..9485392825 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/WrappingStream.cs @@ -0,0 +1,138 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + internal class WrappingStream : Stream + { + private Stream _inner; + private bool _disposed; + + public WrappingStream(Stream inner) + { + _inner = inner; + } + + public void SetInnerStream(Stream inner) + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(WrappingStream)); + } + + _inner = inner; + } + + public override bool CanRead => _inner.CanRead; + + public override bool CanSeek => _inner.CanSeek; + + public override bool CanWrite => _inner.CanWrite; + + public override bool CanTimeout => _inner.CanTimeout; + + public override long Length => _inner.Length; + + public override long Position + { + get => _inner.Position; + set => _inner.Position = value; + } + + public override int ReadTimeout + { + get => _inner.ReadTimeout; + set => _inner.ReadTimeout = value; + } + + public override int WriteTimeout + { + get => _inner.WriteTimeout; + set => _inner.WriteTimeout = value; + } + + public override void Flush() + => _inner.Flush(); + + public override Task FlushAsync(CancellationToken cancellationToken) + => _inner.FlushAsync(cancellationToken); + + public override int Read(byte[] buffer, int offset, int count) + => _inner.Read(buffer, offset, count); + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => _inner.ReadAsync(buffer, offset, count, cancellationToken); + +#if NETCOREAPP2_1 + public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + => _inner.ReadAsync(destination, cancellationToken); +#endif + + public override int ReadByte() + => _inner.ReadByte(); + + public override long Seek(long offset, SeekOrigin origin) + => _inner.Seek(offset, origin); + + public override void SetLength(long value) + => _inner.SetLength(value); + + public override void Write(byte[] buffer, int offset, int count) + => _inner.Write(buffer, offset, count); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => _inner.WriteAsync(buffer, offset, count, cancellationToken); + +#if NETCOREAPP2_1 + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + => _inner.WriteAsync(source, cancellationToken); +#endif + + public override void WriteByte(byte value) + => _inner.WriteByte(value); + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + => _inner.CopyToAsync(destination, bufferSize, cancellationToken); + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + => _inner.BeginRead(buffer, offset, count, callback, state); + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + => _inner.BeginWrite(buffer, offset, count, callback, state); + + public override int EndRead(IAsyncResult asyncResult) + => _inner.EndRead(asyncResult); + + public override void EndWrite(IAsyncResult asyncResult) + => _inner.EndWrite(asyncResult); + + public override object InitializeLifetimeService() + => _inner.InitializeLifetimeService(); + + public override void Close() + => _inner.Close(); + + public override bool Equals(object obj) + => _inner.Equals(obj); + + public override int GetHashCode() + => _inner.GetHashCode(); + + public override string ToString() + => _inner.ToString(); + + protected override void Dispose(bool disposing) + { + if (disposing) + { + _disposed = true; + _inner.Dispose(); + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/WriteOnlyStream.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/WriteOnlyStream.cs new file mode 100644 index 0000000000..c7042e2bb0 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/WriteOnlyStream.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public abstract class WriteOnlyStream : Stream + { + public override bool CanRead => false; + + public override bool CanWrite => true; + + public override int ReadTimeout + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + => throw new NotSupportedException(); + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => throw new NotSupportedException(); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/KestrelServerOptionsSetup.cs b/src/Servers/Kestrel/Core/src/Internal/KestrelServerOptionsSetup.cs new file mode 100644 index 0000000000..18c96e2039 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/KestrelServerOptionsSetup.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public class KestrelServerOptionsSetup : IConfigureOptions + { + private IServiceProvider _services; + + public KestrelServerOptionsSetup(IServiceProvider services) + { + _services = services; + } + + public void Configure(KestrelServerOptions options) + { + options.ApplicationServices = _services; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/LoggerExtensions.cs b/src/Servers/Kestrel/Core/src/Internal/LoggerExtensions.cs new file mode 100644 index 0000000000..0a4c6f3a5e --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/LoggerExtensions.cs @@ -0,0 +1,30 @@ +using System; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Internal +{ + internal static class LoggerExtensions + { + // Category: DefaultHttpsProvider + private static readonly Action _locatedDevelopmentCertificate = + LoggerMessage.Define(LogLevel.Debug, new EventId(0, nameof(LocatedDevelopmentCertificate)), "Using development certificate: {certificateSubjectName} (Thumbprint: {certificateThumbprint})"); + + private static readonly Action _unableToLocateDevelopmentCertificate = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, nameof(UnableToLocateDevelopmentCertificate)), "Unable to locate an appropriate development https certificate."); + + private static readonly Action _failedToLocateDevelopmentCertificateFile = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, nameof(FailedToLocateDevelopmentCertificateFile)), "Failed to locate the development https certificate at '{certificatePath}'."); + + private static readonly Action _failedToLoadDevelopmentCertificate = + LoggerMessage.Define(LogLevel.Debug, new EventId(3, nameof(FailedToLoadDevelopmentCertificate)), "Failed to load the development https certificate at '{certificatePath}'."); + + public static void LocatedDevelopmentCertificate(this ILogger logger, X509Certificate2 certificate) => _locatedDevelopmentCertificate(logger, certificate.Subject, certificate.Thumbprint, null); + + public static void UnableToLocateDevelopmentCertificate(this ILogger logger) => _unableToLocateDevelopmentCertificate(logger, null); + + public static void FailedToLocateDevelopmentCertificateFile(this ILogger logger, string certificatePath) => _failedToLocateDevelopmentCertificateFile(logger, certificatePath, null); + + public static void FailedToLoadDevelopmentCertificate(this ILogger logger, string certificatePath) => _failedToLoadDevelopmentCertificate(logger, certificatePath, null); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/ServerAddressesFeature.cs b/src/Servers/Kestrel/Core/src/Internal/ServerAddressesFeature.cs new file mode 100644 index 0000000000..f8bcd13cde --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/ServerAddressesFeature.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using Microsoft.AspNetCore.Hosting.Server.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + internal class ServerAddressesFeature : IServerAddressesFeature + { + public ICollection Addresses { get; } = new List(); + public bool PreferHostingUrls { get; set; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/ServiceContext.cs b/src/Servers/Kestrel/Core/src/Internal/ServiceContext.cs new file mode 100644 index 0000000000..1020a6fdbd --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/ServiceContext.cs @@ -0,0 +1,26 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO.Pipelines; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public class ServiceContext + { + public IKestrelTrace Log { get; set; } + + public PipeScheduler Scheduler { get; set; } + + public IHttpParser HttpParser { get; set; } + + public ISystemClock SystemClock { get; set; } + + public DateHeaderValueManager DateHeaderValueManager { get; set; } + + public HttpConnectionManager ConnectionManager { get; set; } + + public KestrelServerOptions ServerOptions { get; set; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/TlsConnectionFeature.cs b/src/Servers/Kestrel/Core/src/Internal/TlsConnectionFeature.cs new file mode 100644 index 0000000000..a914024131 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/TlsConnectionFeature.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Https.Internal +{ + internal class TlsConnectionFeature : ITlsConnectionFeature, ITlsApplicationProtocolFeature + { + public X509Certificate2 ClientCertificate { get; set; } + + public ReadOnlyMemory ApplicationProtocol { get; set; } + + public Task GetClientCertificateAsync(CancellationToken cancellationToken) + { + return Task.FromResult(ClientCertificate); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/KestrelConfigurationLoader.cs b/src/Servers/Kestrel/Core/src/KestrelConfigurationLoader.cs new file mode 100644 index 0000000000..60204443f6 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/KestrelConfigurationLoader.cs @@ -0,0 +1,374 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using Microsoft.AspNetCore.Certificates.Generation; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Internal; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel +{ + public class KestrelConfigurationLoader + { + internal KestrelConfigurationLoader(KestrelServerOptions options, IConfiguration configuration) + { + Options = options ?? throw new ArgumentNullException(nameof(options)); + Configuration = configuration ?? throw new ArgumentNullException(nameof(configuration)); + } + + public KestrelServerOptions Options { get; } + public IConfiguration Configuration { get; } + private IDictionary> EndpointConfigurations { get; } + = new Dictionary>(0, StringComparer.OrdinalIgnoreCase); + // Actions that will be delayed until Load so that they aren't applied if the configuration loader is replaced. + private IList EndpointsToAdd { get; } = new List(); + + /// + /// Specifies a configuration Action to run when an endpoint with the given name is loaded from configuration. + /// + public KestrelConfigurationLoader Endpoint(string name, Action configureOptions) + { + if (string.IsNullOrEmpty(name)) + { + throw new ArgumentNullException(nameof(name)); + } + + EndpointConfigurations[name] = configureOptions ?? throw new ArgumentNullException(nameof(configureOptions)); + return this; + } + + /// + /// Bind to given IP address and port. + /// + public KestrelConfigurationLoader Endpoint(IPAddress address, int port) => Endpoint(address, port, _ => { }); + + /// + /// Bind to given IP address and port. + /// + public KestrelConfigurationLoader Endpoint(IPAddress address, int port, Action configure) + { + if (address == null) + { + throw new ArgumentNullException(nameof(address)); + } + + return Endpoint(new IPEndPoint(address, port), configure); + } + + /// + /// Bind to given IP endpoint. + /// + public KestrelConfigurationLoader Endpoint(IPEndPoint endPoint) => Endpoint(endPoint, _ => { }); + + /// + /// Bind to given IP address and port. + /// + public KestrelConfigurationLoader Endpoint(IPEndPoint endPoint, Action configure) + { + if (endPoint == null) + { + throw new ArgumentNullException(nameof(endPoint)); + } + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + EndpointsToAdd.Add(() => + { + Options.Listen(endPoint, configure); + }); + + return this; + } + + /// + /// Listens on ::1 and 127.0.0.1 with the given port. Requesting a dynamic port by specifying 0 is not supported + /// for this type of endpoint. + /// + public KestrelConfigurationLoader LocalhostEndpoint(int port) => LocalhostEndpoint(port, options => { }); + + /// + /// Listens on ::1 and 127.0.0.1 with the given port. Requesting a dynamic port by specifying 0 is not supported + /// for this type of endpoint. + /// + public KestrelConfigurationLoader LocalhostEndpoint(int port, Action configure) + { + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + EndpointsToAdd.Add(() => + { + Options.ListenLocalhost(port, configure); + }); + + return this; + } + + /// + /// Listens on all IPs using IPv6 [::], or IPv4 0.0.0.0 if IPv6 is not supported. + /// + public KestrelConfigurationLoader AnyIPEndpoint(int port) => AnyIPEndpoint(port, options => { }); + + /// + /// Listens on all IPs using IPv6 [::], or IPv4 0.0.0.0 if IPv6 is not supported. + /// + public KestrelConfigurationLoader AnyIPEndpoint(int port, Action configure) + { + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + EndpointsToAdd.Add(() => + { + Options.ListenAnyIP(port, configure); + }); + + return this; + } + + /// + /// Bind to given Unix domain socket path. + /// + public KestrelConfigurationLoader UnixSocketEndpoint(string socketPath) => UnixSocketEndpoint(socketPath, _ => { }); + + /// + /// Bind to given Unix domain socket path. + /// + public KestrelConfigurationLoader UnixSocketEndpoint(string socketPath, Action configure) + { + if (socketPath == null) + { + throw new ArgumentNullException(nameof(socketPath)); + } + if (socketPath.Length == 0 || socketPath[0] != '/') + { + throw new ArgumentException(CoreStrings.UnixSocketPathMustBeAbsolute, nameof(socketPath)); + } + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + EndpointsToAdd.Add(() => + { + Options.ListenUnixSocket(socketPath, configure); + }); + + return this; + } + + /// + /// Open a socket file descriptor. + /// + public KestrelConfigurationLoader HandleEndpoint(ulong handle) => HandleEndpoint(handle, _ => { }); + + /// + /// Open a socket file descriptor. + /// + public KestrelConfigurationLoader HandleEndpoint(ulong handle, Action configure) + { + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + EndpointsToAdd.Add(() => + { + Options.ListenHandle(handle, configure); + }); + + return this; + } + + public void Load() + { + if (Options.ConfigurationLoader == null) + { + // The loader has already been run. + return; + } + Options.ConfigurationLoader = null; + + var configReader = new ConfigurationReader(Configuration); + + LoadDefaultCert(configReader); + + foreach (var endpoint in configReader.Endpoints) + { + var listenOptions = AddressBinder.ParseAddress(endpoint.Url, out var https); + Options.ApplyEndpointDefaults(listenOptions); + + // Compare to UseHttps(httpsOptions => { }) + var httpsOptions = new HttpsConnectionAdapterOptions(); + if (https) + { + // Defaults + Options.ApplyHttpsDefaults(httpsOptions); + + // Specified + httpsOptions.ServerCertificate = LoadCertificate(endpoint.Certificate, endpoint.Name) + ?? httpsOptions.ServerCertificate; + + // Fallback + Options.ApplyDefaultCert(httpsOptions); + } + + if (EndpointConfigurations.TryGetValue(endpoint.Name, out var configureEndpoint)) + { + var endpointConfig = new EndpointConfiguration(https, listenOptions, httpsOptions, endpoint.ConfigSection); + configureEndpoint(endpointConfig); + } + + // EndpointDefaults or configureEndpoint may have added an https adapter. + if (https && !listenOptions.ConnectionAdapters.Any(f => f.IsHttps)) + { + if (httpsOptions.ServerCertificate == null && httpsOptions.ServerCertificateSelector == null) + { + throw new InvalidOperationException(CoreStrings.NoCertSpecifiedNoDevelopmentCertificateFound); + } + + listenOptions.UseHttps(httpsOptions); + } + + Options.ListenOptions.Add(listenOptions); + } + + foreach (var action in EndpointsToAdd) + { + action(); + } + } + + private void LoadDefaultCert(ConfigurationReader configReader) + { + if (configReader.Certificates.TryGetValue("Default", out var defaultCertConfig)) + { + var defaultCert = LoadCertificate(defaultCertConfig, "Default"); + if (defaultCert != null) + { + Options.DefaultCertificate = defaultCert; + } + } + else + { + var logger = Options.ApplicationServices.GetRequiredService>(); + var certificate = FindDeveloperCertificateFile(configReader, logger); + if (certificate != null) + { + logger.LocatedDevelopmentCertificate(certificate); + Options.DefaultCertificate = certificate; + } + } + } + + private X509Certificate2 FindDeveloperCertificateFile(ConfigurationReader configReader, ILogger logger) + { + string certificatePath = null; + try + { + if (configReader.Certificates.TryGetValue("Development", out var certificateConfig) && + certificateConfig.Path == null && + certificateConfig.Password != null && + TryGetCertificatePath(out certificatePath) && + File.Exists(certificatePath)) + { + var certificate = new X509Certificate2(certificatePath, certificateConfig.Password); + return IsDevelopmentCertificate(certificate) ? certificate : null; + } + else if (!File.Exists(certificatePath)) + { + logger.FailedToLocateDevelopmentCertificateFile(certificatePath); + } + } + catch (CryptographicException) + { + logger.FailedToLoadDevelopmentCertificate(certificatePath); + } + + return null; + } + + private bool IsDevelopmentCertificate(X509Certificate2 certificate) + { + if (!string.Equals(certificate.Subject, "CN=localhost", StringComparison.Ordinal)) + { + return false; + } + + foreach (var ext in certificate.Extensions) + { + if (string.Equals(ext.Oid.Value, CertificateManager.AspNetHttpsOid, StringComparison.Ordinal)) + { + return true; + } + } + + return false; + } + + private bool TryGetCertificatePath(out string path) + { + var hostingEnvironment = Options.ApplicationServices.GetRequiredService(); + var appName = hostingEnvironment.ApplicationName; + + // This will go away when we implement + // https://github.com/aspnet/Hosting/issues/1294 + var appData = Environment.GetEnvironmentVariable("APPDATA"); + var home = Environment.GetEnvironmentVariable("HOME"); + var basePath = appData != null ? Path.Combine(appData, "ASP.NET", "https") : null; + basePath = basePath ?? (home != null ? Path.Combine(home, ".aspnet", "https") : null); + path = basePath != null ? Path.Combine(basePath, $"{appName}.pfx") : null; + return path != null; + } + + private X509Certificate2 LoadCertificate(CertificateConfig certInfo, string endpointName) + { + if (certInfo.IsFileCert && certInfo.IsStoreCert) + { + throw new InvalidOperationException(CoreStrings.FormatMultipleCertificateSources(endpointName)); + } + else if (certInfo.IsFileCert) + { + var env = Options.ApplicationServices.GetRequiredService(); + return new X509Certificate2(Path.Combine(env.ContentRootPath, certInfo.Path), certInfo.Password); + } + else if (certInfo.IsStoreCert) + { + return LoadFromStoreCert(certInfo); + } + return null; + } + + private static X509Certificate2 LoadFromStoreCert(CertificateConfig certInfo) + { + var subject = certInfo.Subject; + var storeName = certInfo.Store; + var location = certInfo.Location; + var storeLocation = StoreLocation.CurrentUser; + if (!string.IsNullOrEmpty(location)) + { + storeLocation = (StoreLocation)Enum.Parse(typeof(StoreLocation), location, ignoreCase: true); + } + var allowInvalid = certInfo.AllowInvalid ?? false; + + return CertificateLoader.LoadFromStoreCert(subject, storeName, storeLocation, allowInvalid); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/KestrelServer.cs b/src/Servers/Kestrel/Core/src/KestrelServer.cs new file mode 100644 index 0000000000..9e549a2282 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/KestrelServer.cs @@ -0,0 +1,240 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Hosting.Server.Features; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core +{ + public class KestrelServer : IServer + { + private readonly List _transports = new List(); + private readonly Heartbeat _heartbeat; + private readonly IServerAddressesFeature _serverAddresses; + private readonly ITransportFactory _transportFactory; + + private bool _hasStarted; + private int _stopping; + private readonly TaskCompletionSource _stoppedTcs = new TaskCompletionSource(); + + public KestrelServer(IOptions options, ITransportFactory transportFactory, ILoggerFactory loggerFactory) + : this(transportFactory, CreateServiceContext(options, loggerFactory)) + { + } + + // For testing + internal KestrelServer(ITransportFactory transportFactory, ServiceContext serviceContext) + { + if (transportFactory == null) + { + throw new ArgumentNullException(nameof(transportFactory)); + } + + _transportFactory = transportFactory; + ServiceContext = serviceContext; + + var httpHeartbeatManager = new HttpHeartbeatManager(serviceContext.ConnectionManager); + _heartbeat = new Heartbeat( + new IHeartbeatHandler[] { serviceContext.DateHeaderValueManager, httpHeartbeatManager }, + serviceContext.SystemClock, + DebuggerWrapper.Singleton, + Trace); + + Features = new FeatureCollection(); + _serverAddresses = new ServerAddressesFeature(); + Features.Set(_serverAddresses); + } + + private static ServiceContext CreateServiceContext(IOptions options, ILoggerFactory loggerFactory) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + if (loggerFactory == null) + { + throw new ArgumentNullException(nameof(loggerFactory)); + } + + var serverOptions = options.Value ?? new KestrelServerOptions(); + var logger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Server.Kestrel"); + var trace = new KestrelTrace(logger); + var connectionManager = new HttpConnectionManager( + trace, + serverOptions.Limits.MaxConcurrentUpgradedConnections); + + var systemClock = new SystemClock(); + var dateHeaderValueManager = new DateHeaderValueManager(systemClock); + + // TODO: This logic will eventually move into the IConnectionHandler and off + // the service context once we get to https://github.com/aspnet/KestrelHttpServer/issues/1662 + PipeScheduler scheduler = null; + switch (serverOptions.ApplicationSchedulingMode) + { + case SchedulingMode.Default: + case SchedulingMode.ThreadPool: + scheduler = PipeScheduler.ThreadPool; + break; + case SchedulingMode.Inline: + scheduler = PipeScheduler.Inline; + break; + default: + throw new NotSupportedException(CoreStrings.FormatUnknownTransportMode(serverOptions.ApplicationSchedulingMode)); + } + + return new ServiceContext + { + Log = trace, + HttpParser = new HttpParser(trace.IsEnabled(LogLevel.Information)), + Scheduler = scheduler, + SystemClock = systemClock, + DateHeaderValueManager = dateHeaderValueManager, + ConnectionManager = connectionManager, + ServerOptions = serverOptions + }; + } + + public IFeatureCollection Features { get; } + + public KestrelServerOptions Options => ServiceContext.ServerOptions; + + private ServiceContext ServiceContext { get; } + + private IKestrelTrace Trace => ServiceContext.Log; + + private HttpConnectionManager ConnectionManager => ServiceContext.ConnectionManager; + + public async Task StartAsync(IHttpApplication application, CancellationToken cancellationToken) + { + try + { + if (!BitConverter.IsLittleEndian) + { + throw new PlatformNotSupportedException(CoreStrings.BigEndianNotSupported); + } + + ValidateOptions(); + + if (_hasStarted) + { + // The server has already started and/or has not been cleaned up yet + throw new InvalidOperationException(CoreStrings.ServerAlreadyStarted); + } + _hasStarted = true; + _heartbeat.Start(); + + async Task OnBind(ListenOptions endpoint) + { + // Add the HTTP middleware as the terminal connection middleware + endpoint.UseHttpServer(endpoint.ConnectionAdapters, ServiceContext, application, endpoint.Protocols); + + var connectionDelegate = endpoint.Build(); + + // Add the connection limit middleware + if (Options.Limits.MaxConcurrentConnections.HasValue) + { + connectionDelegate = new ConnectionLimitMiddleware(connectionDelegate, Options.Limits.MaxConcurrentConnections.Value, Trace).OnConnectionAsync; + } + + var connectionDispatcher = new ConnectionDispatcher(ServiceContext, connectionDelegate); + var transport = _transportFactory.Create(endpoint, connectionDispatcher); + _transports.Add(transport); + + await transport.BindAsync().ConfigureAwait(false); + } + + await AddressBinder.BindAsync(_serverAddresses, Options, Trace, OnBind).ConfigureAwait(false); + } + catch (Exception ex) + { + Trace.LogCritical(0, ex, "Unable to start Kestrel."); + Dispose(); + throw; + } + } + + // Graceful shutdown if possible + public async Task StopAsync(CancellationToken cancellationToken) + { + if (Interlocked.Exchange(ref _stopping, 1) == 1) + { + await _stoppedTcs.Task.ConfigureAwait(false); + return; + } + + try + { + var tasks = new Task[_transports.Count]; + for (int i = 0; i < _transports.Count; i++) + { + tasks[i] = _transports[i].UnbindAsync(); + } + await Task.WhenAll(tasks).ConfigureAwait(false); + + if (!await ConnectionManager.CloseAllConnectionsAsync(cancellationToken).ConfigureAwait(false)) + { + Trace.NotAllConnectionsClosedGracefully(); + + if (!await ConnectionManager.AbortAllConnectionsAsync().ConfigureAwait(false)) + { + Trace.NotAllConnectionsAborted(); + } + } + + for (int i = 0; i < _transports.Count; i++) + { + tasks[i] = _transports[i].StopAsync(); + } + await Task.WhenAll(tasks).ConfigureAwait(false); + + _heartbeat.Dispose(); + } + catch (Exception ex) + { + _stoppedTcs.TrySetException(ex); + throw; + } + + _stoppedTcs.TrySetResult(null); + } + + // Ungraceful shutdown + public void Dispose() + { + var cancelledTokenSource = new CancellationTokenSource(); + cancelledTokenSource.Cancel(); + StopAsync(cancelledTokenSource.Token).GetAwaiter().GetResult(); + } + + private void ValidateOptions() + { + Options.ConfigurationLoader?.Load(); + + if (Options.Limits.MaxRequestBufferSize.HasValue && + Options.Limits.MaxRequestBufferSize < Options.Limits.MaxRequestLineSize) + { + throw new InvalidOperationException( + CoreStrings.FormatMaxRequestBufferSmallerThanRequestLineBuffer(Options.Limits.MaxRequestBufferSize.Value, Options.Limits.MaxRequestLineSize)); + } + + if (Options.Limits.MaxRequestBufferSize.HasValue && + Options.Limits.MaxRequestBufferSize < Options.Limits.MaxRequestHeadersTotalSize) + { + throw new InvalidOperationException( + CoreStrings.FormatMaxRequestBufferSmallerThanRequestHeaderBuffer(Options.Limits.MaxRequestBufferSize.Value, Options.Limits.MaxRequestHeadersTotalSize)); + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/KestrelServerLimits.cs b/src/Servers/Kestrel/Core/src/KestrelServerLimits.cs new file mode 100644 index 0000000000..f2f8e773ac --- /dev/null +++ b/src/Servers/Kestrel/Core/src/KestrelServerLimits.cs @@ -0,0 +1,291 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core +{ + public class KestrelServerLimits + { + // Matches the non-configurable default response buffer size for Kestrel in 1.0.0 + private long? _maxResponseBufferSize = 64 * 1024; + + // Matches the default client_max_body_size in nginx. + // Also large enough that most requests should be under the limit. + private long? _maxRequestBufferSize = 1024 * 1024; + + // Matches the default large_client_header_buffers in nginx. + private int _maxRequestLineSize = 8 * 1024; + + // Matches the default large_client_header_buffers in nginx. + private int _maxRequestHeadersTotalSize = 32 * 1024; + + // Matches the default maxAllowedContentLength in IIS (~28.6 MB) + // https://www.iis.net/configreference/system.webserver/security/requestfiltering/requestlimits#005 + private long? _maxRequestBodySize = 30000000; + + // Matches the default LimitRequestFields in Apache httpd. + private int _maxRequestHeaderCount = 100; + + // Matches the default http.sys connectionTimeout. + private TimeSpan _keepAliveTimeout = TimeSpan.FromMinutes(2); + + private TimeSpan _requestHeadersTimeout = TimeSpan.FromSeconds(30); + + // Unlimited connections are allowed by default. + private long? _maxConcurrentConnections = null; + private long? _maxConcurrentUpgradedConnections = null; + + /// + /// Gets or sets the maximum size of the response buffer before write + /// calls begin to block or return tasks that don't complete until the + /// buffer size drops below the configured limit. + /// Defaults to 65,536 bytes (64 KB). + /// + /// + /// When set to null, the size of the response buffer is unlimited. + /// When set to zero, all write calls will block or return tasks that + /// don't complete until the entire response buffer is flushed. + /// + public long? MaxResponseBufferSize + { + get => _maxResponseBufferSize; + set + { + if (value.HasValue && value.Value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.NonNegativeNumberOrNullRequired); + } + _maxResponseBufferSize = value; + } + } + + /// + /// Gets or sets the maximum size of the request buffer. + /// Defaults to 1,048,576 bytes (1 MB). + /// + /// + /// When set to null, the size of the request buffer is unlimited. + /// + public long? MaxRequestBufferSize + { + get => _maxRequestBufferSize; + set + { + if (value.HasValue && value.Value <= 0) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.PositiveNumberOrNullRequired); + } + _maxRequestBufferSize = value; + } + } + + /// + /// Gets or sets the maximum allowed size for the HTTP request line. + /// Defaults to 8,192 bytes (8 KB). + /// + /// + /// + public int MaxRequestLineSize + { + get => _maxRequestLineSize; + set + { + if (value <= 0) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.PositiveNumberRequired); + } + _maxRequestLineSize = value; + } + } + + /// + /// Gets or sets the maximum allowed size for the HTTP request headers. + /// Defaults to 32,768 bytes (32 KB). + /// + /// + /// + public int MaxRequestHeadersTotalSize + { + get => _maxRequestHeadersTotalSize; + set + { + if (value <= 0) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.PositiveNumberRequired); + } + _maxRequestHeadersTotalSize = value; + } + } + + /// + /// Gets or sets the maximum allowed number of headers per HTTP request. + /// Defaults to 100. + /// + /// + /// + public int MaxRequestHeaderCount + { + get => _maxRequestHeaderCount; + set + { + if (value <= 0) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.PositiveNumberRequired); + } + _maxRequestHeaderCount = value; + } + } + + /// + /// Gets or sets the maximum allowed size of any request body in bytes. + /// When set to null, the maximum request body size is unlimited. + /// This limit has no effect on upgraded connections which are always unlimited. + /// This can be overridden per-request via . + /// Defaults to 30,000,000 bytes, which is approximately 28.6MB. + /// + /// + /// + public long? MaxRequestBodySize + { + get => _maxRequestBodySize; + set + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.NonNegativeNumberOrNullRequired); + } + _maxRequestBodySize = value; + } + } + + /// + /// Gets or sets the keep-alive timeout. + /// Defaults to 2 minutes. + /// + /// + /// + public TimeSpan KeepAliveTimeout + { + get => _keepAliveTimeout; + set + { + if (value <= TimeSpan.Zero && value != Timeout.InfiniteTimeSpan) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.PositiveTimeSpanRequired); + } + _keepAliveTimeout = value != Timeout.InfiniteTimeSpan ? value : TimeSpan.MaxValue; + } + } + + /// + /// Gets or sets the maximum amount of time the server will spend receiving request headers. + /// Defaults to 30 seconds. + /// + /// + /// + public TimeSpan RequestHeadersTimeout + { + get => _requestHeadersTimeout; + set + { + if (value <= TimeSpan.Zero && value != Timeout.InfiniteTimeSpan) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.PositiveTimeSpanRequired); + } + _requestHeadersTimeout = value != Timeout.InfiniteTimeSpan ? value : TimeSpan.MaxValue; + } + } + + /// + /// Gets or sets the maximum number of open connections. When set to null, the number of connections is unlimited. + /// + /// Defaults to null. + /// + /// + /// + /// + /// When a connection is upgraded to another protocol, such as WebSockets, its connection is counted against the + /// limit instead of . + /// + /// + public long? MaxConcurrentConnections + { + get => _maxConcurrentConnections; + set + { + if (value.HasValue && value <= 0) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.PositiveNumberOrNullRequired); + } + _maxConcurrentConnections = value; + } + } + + /// + /// Gets or sets the maximum number of open, upgraded connections. When set to null, the number of upgraded connections is unlimited. + /// An upgraded connection is one that has been switched from HTTP to another protocol, such as WebSockets. + /// + /// Defaults to null. + /// + /// + /// + /// + /// When a connection is upgraded to another protocol, such as WebSockets, its connection is counted against the + /// limit instead of . + /// + /// + public long? MaxConcurrentUpgradedConnections + { + get => _maxConcurrentUpgradedConnections; + set + { + if (value.HasValue && value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.NonNegativeNumberOrNullRequired); + } + _maxConcurrentUpgradedConnections = value; + } + } + + /// + /// Gets or sets the request body minimum data rate in bytes/second. + /// Setting this property to null indicates no minimum data rate should be enforced. + /// This limit has no effect on upgraded connections which are always unlimited. + /// This can be overridden per-request via . + /// Defaults to 240 bytes/second with a 5 second grace period. + /// + /// + /// + public MinDataRate MinRequestBodyDataRate { get; set; } = + // Matches the default IIS minBytesPerSecond + new MinDataRate(bytesPerSecond: 240, gracePeriod: TimeSpan.FromSeconds(5)); + + /// + /// Gets or sets the response minimum data rate in bytes/second. + /// Setting this property to null indicates no minimum data rate should be enforced. + /// This limit has no effect on upgraded connections which are always unlimited. + /// This can be overridden per-request via . + /// + /// Defaults to 240 bytes/second with a 5 second grace period. + /// + /// + /// + /// + /// Contrary to the request body minimum data rate, this rate applies to the response status line and headers as well. + /// + /// + /// This rate is enforced per write operation instead of being averaged over the life of the response. Whenever the server + /// writes a chunk of data, a timer is set to the maximum of the grace period set in this property or the length of the write in + /// bytes divided by the data rate (i.e. the maximum amount of time that write should take to complete with the specified data rate). + /// The connection is aborted if the write has not completed by the time that timer expires. + /// + /// + public MinDataRate MinResponseDataRate { get; set; } = + // Matches the default IIS minBytesPerSecond + new MinDataRate(bytesPerSecond: 240, gracePeriod: TimeSpan.FromSeconds(5)); + } +} diff --git a/src/Servers/Kestrel/Core/src/KestrelServerOptions.cs b/src/Servers/Kestrel/Core/src/KestrelServerOptions.cs new file mode 100644 index 0000000000..6c62edc336 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/KestrelServerOptions.cs @@ -0,0 +1,335 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Security.Cryptography.X509Certificates; +using Microsoft.AspNetCore.Certificates.Generation; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Server.Kestrel.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core +{ + /// + /// Provides programmatic configuration of Kestrel-specific features. + /// + public class KestrelServerOptions + { + /// + /// Configures the endpoints that Kestrel should listen to. + /// + /// + /// If this list is empty, the server.urls setting (e.g. UseUrls) is used. + /// + internal List ListenOptions { get; } = new List(); + + /// + /// Gets or sets whether the Server header should be included in each response. + /// + /// + /// Defaults to true. + /// + public bool AddServerHeader { get; set; } = true; + + /// + /// Gets or sets a value that determines how Kestrel should schedule user callbacks. + /// + /// The default mode is + public SchedulingMode ApplicationSchedulingMode { get; set; } = SchedulingMode.Default; + + /// + /// Gets or sets a value that controls whether synchronous IO is allowed for the and + /// + /// + /// Defaults to true. + /// + public bool AllowSynchronousIO { get; set; } = true; + + /// + /// Enables the Listen options callback to resolve and use services registered by the application during startup. + /// Typically initialized by UseKestrel()"/>. + /// + public IServiceProvider ApplicationServices { get; set; } + + /// + /// Provides access to request limit options. + /// + public KestrelServerLimits Limits { get; } = new KestrelServerLimits(); + + /// + /// Provides a configuration source where endpoints will be loaded from on server start. + /// The default is null. + /// + public KestrelConfigurationLoader ConfigurationLoader { get; set; } + + /// + /// A default configuration action for all endpoints. Use for Listen, configuration, the default url, and URLs. + /// + private Action EndpointDefaults { get; set; } = _ => { }; + + /// + /// A default configuration action for all https endpoints. + /// + private Action HttpsDefaults { get; set; } = _ => { }; + + /// + /// The default server certificate for https endpoints. This is applied lazily after HttpsDefaults and user options. + /// + internal X509Certificate2 DefaultCertificate { get; set; } + + /// + /// Has the default dev certificate load been attempted? + /// + internal bool IsDevCertLoaded { get; set; } + + /// + /// Specifies a configuration Action to run for each newly created endpoint. Calling this again will replace + /// the prior action. + /// + public void ConfigureEndpointDefaults(Action configureOptions) + { + EndpointDefaults = configureOptions ?? throw new ArgumentNullException(nameof(configureOptions)); + } + + internal void ApplyEndpointDefaults(ListenOptions listenOptions) + { + listenOptions.KestrelServerOptions = this; + EndpointDefaults(listenOptions); + } + + /// + /// Specifies a configuration Action to run for each newly created https endpoint. Calling this again will replace + /// the prior action. + /// + public void ConfigureHttpsDefaults(Action configureOptions) + { + HttpsDefaults = configureOptions ?? throw new ArgumentNullException(nameof(configureOptions)); + } + + internal void ApplyHttpsDefaults(HttpsConnectionAdapterOptions httpsOptions) + { + HttpsDefaults(httpsOptions); + } + + internal void ApplyDefaultCert(HttpsConnectionAdapterOptions httpsOptions) + { + if (httpsOptions.ServerCertificate != null || httpsOptions.ServerCertificateSelector != null) + { + return; + } + + EnsureDefaultCert(); + + httpsOptions.ServerCertificate = DefaultCertificate; + } + + private void EnsureDefaultCert() + { + if (DefaultCertificate == null && !IsDevCertLoaded) + { + IsDevCertLoaded = true; // Only try once + var logger = ApplicationServices.GetRequiredService>(); + try + { + var certificateManager = new CertificateManager(); + DefaultCertificate = certificateManager.ListCertificates(CertificatePurpose.HTTPS, StoreName.My, StoreLocation.CurrentUser, isValid: true) + .FirstOrDefault(); + + if (DefaultCertificate != null) + { + logger.LocatedDevelopmentCertificate(DefaultCertificate); + } + else + { + logger.UnableToLocateDevelopmentCertificate(); + } + } + catch + { + logger.UnableToLocateDevelopmentCertificate(); + } + } + } + + /// + /// Creates a configuration loader for setting up Kestrel. + /// + public KestrelConfigurationLoader Configure() + { + var loader = new KestrelConfigurationLoader(this, new ConfigurationBuilder().Build()); + ConfigurationLoader = loader; + return loader; + } + + /// + /// Creates a configuration loader for setting up Kestrel that takes an IConfiguration as input. + /// This configuration must be scoped to the configuration section for Kestrel. + /// + public KestrelConfigurationLoader Configure(IConfiguration config) + { + var loader = new KestrelConfigurationLoader(this, config); + ConfigurationLoader = loader; + return loader; + } + + /// + /// Bind to given IP address and port. + /// + public void Listen(IPAddress address, int port) + { + Listen(address, port, _ => { }); + } + + /// + /// Bind to given IP address and port. + /// The callback configures endpoint-specific settings. + /// + public void Listen(IPAddress address, int port, Action configure) + { + if (address == null) + { + throw new ArgumentNullException(nameof(address)); + } + + Listen(new IPEndPoint(address, port), configure); + } + + /// + /// Bind to given IP endpoint. + /// + public void Listen(IPEndPoint endPoint) + { + Listen(endPoint, _ => { }); + } + + /// + /// Bind to given IP address and port. + /// The callback configures endpoint-specific settings. + /// + public void Listen(IPEndPoint endPoint, Action configure) + { + if (endPoint == null) + { + throw new ArgumentNullException(nameof(endPoint)); + } + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + var listenOptions = new ListenOptions(endPoint); + ApplyEndpointDefaults(listenOptions); + configure(listenOptions); + ListenOptions.Add(listenOptions); + } + + /// + /// Listens on ::1 and 127.0.0.1 with the given port. Requesting a dynamic port by specifying 0 is not supported + /// for this type of endpoint. + /// + public void ListenLocalhost(int port) => ListenLocalhost(port, options => { }); + + /// + /// Listens on ::1 and 127.0.0.1 with the given port. Requesting a dynamic port by specifying 0 is not supported + /// for this type of endpoint. + /// + public void ListenLocalhost(int port, Action configure) + { + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + var listenOptions = new LocalhostListenOptions(port); + ApplyEndpointDefaults(listenOptions); + configure(listenOptions); + ListenOptions.Add(listenOptions); + } + + /// + /// Listens on all IPs using IPv6 [::], or IPv4 0.0.0.0 if IPv6 is not supported. + /// + public void ListenAnyIP(int port) => ListenAnyIP(port, options => { }); + + /// + /// Listens on all IPs using IPv6 [::], or IPv4 0.0.0.0 if IPv6 is not supported. + /// + public void ListenAnyIP(int port, Action configure) + { + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + var listenOptions = new AnyIPListenOptions(port); + ApplyEndpointDefaults(listenOptions); + configure(listenOptions); + ListenOptions.Add(listenOptions); + } + + /// + /// Bind to given Unix domain socket path. + /// + public void ListenUnixSocket(string socketPath) + { + ListenUnixSocket(socketPath, _ => { }); + } + + /// + /// Bind to given Unix domain socket path. + /// Specify callback to configure endpoint-specific settings. + /// + public void ListenUnixSocket(string socketPath, Action configure) + { + if (socketPath == null) + { + throw new ArgumentNullException(nameof(socketPath)); + } + if (socketPath.Length == 0 || socketPath[0] != '/') + { + throw new ArgumentException(CoreStrings.UnixSocketPathMustBeAbsolute, nameof(socketPath)); + } + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + var listenOptions = new ListenOptions(socketPath); + ApplyEndpointDefaults(listenOptions); + configure(listenOptions); + ListenOptions.Add(listenOptions); + } + + /// + /// Open a socket file descriptor. + /// + public void ListenHandle(ulong handle) + { + ListenHandle(handle, _ => { }); + } + + /// + /// Open a socket file descriptor. + /// The callback configures endpoint-specific settings. + /// + public void ListenHandle(ulong handle, Action configure) + { + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + var listenOptions = new ListenOptions(handle); + ApplyEndpointDefaults(listenOptions); + configure(listenOptions); + ListenOptions.Add(listenOptions); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/ListenOptions.cs b/src/Servers/Kestrel/Core/src/ListenOptions.cs new file mode 100644 index 0000000000..5ae8b31468 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/ListenOptions.cs @@ -0,0 +1,193 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core +{ + /// + /// Describes either an , Unix domain socket path, or a file descriptor for an already open + /// socket that Kestrel should bind to or open. + /// + public class ListenOptions : IEndPointInformation, IConnectionBuilder + { + private FileHandleType _handleType; + internal readonly List> _middleware = new List>(); + + internal ListenOptions(IPEndPoint endPoint) + { + Type = ListenType.IPEndPoint; + IPEndPoint = endPoint; + } + + internal ListenOptions(string socketPath) + { + Type = ListenType.SocketPath; + SocketPath = socketPath; + } + + internal ListenOptions(ulong fileHandle) + : this(fileHandle, FileHandleType.Auto) + { + } + + internal ListenOptions(ulong fileHandle, FileHandleType handleType) + { + Type = ListenType.FileHandle; + FileHandle = fileHandle; + switch (handleType) + { + case FileHandleType.Auto: + case FileHandleType.Tcp: + case FileHandleType.Pipe: + _handleType = handleType; + break; + default: + throw new NotSupportedException(); + } + } + + /// + /// The type of interface being described: either an , Unix domain socket path, or a file descriptor. + /// + public ListenType Type { get; } + + public FileHandleType HandleType + { + get => _handleType; + set + { + if (value == _handleType) + { + return; + } + if (Type != ListenType.FileHandle || _handleType != FileHandleType.Auto) + { + throw new InvalidOperationException(); + } + + switch (value) + { + case FileHandleType.Tcp: + case FileHandleType.Pipe: + _handleType = value; + break; + default: + throw new ArgumentException(nameof(HandleType)); + } + } + } + + // IPEndPoint is mutable so port 0 can be updated to the bound port. + /// + /// The to bind to. + /// Only set if the is . + /// + public IPEndPoint IPEndPoint { get; set; } + + /// + /// The absolute path to a Unix domain socket to bind to. + /// Only set if the is . + /// + public string SocketPath { get; } + + /// + /// A file descriptor for the socket to open. + /// Only set if the is . + /// + public ulong FileHandle { get; } + + /// + /// Enables an to resolve and use services registered by the application during startup. + /// Only set if accessed from the callback of a Listen* method. + /// + public KestrelServerOptions KestrelServerOptions { get; internal set; } + + /// + /// Set to false to enable Nagle's algorithm for all connections. + /// + /// + /// Defaults to true. + /// + public bool NoDelay { get; set; } = true; + + /// + /// The protocols enabled on this endpoint. + /// + /// Defaults to HTTP/1.x only. + internal HttpProtocols Protocols { get; set; } = HttpProtocols.Http1; + + /// + /// Gets the that allows each connection + /// to be intercepted and transformed. + /// Configured by the UseHttps() and + /// extension methods. + /// + /// + /// Defaults to empty. + /// + public List ConnectionAdapters { get; } = new List(); + + public IServiceProvider ApplicationServices => KestrelServerOptions?.ApplicationServices; + + /// + /// Gets the name of this endpoint to display on command-line when the web server starts. + /// + internal virtual string GetDisplayName() + { + var scheme = ConnectionAdapters.Any(f => f.IsHttps) + ? "https" + : "http"; + + switch (Type) + { + case ListenType.IPEndPoint: + return $"{scheme}://{IPEndPoint}"; + case ListenType.SocketPath: + return $"{scheme}://unix:{SocketPath}"; + case ListenType.FileHandle: + return $"{scheme}://"; + default: + throw new InvalidOperationException(); + } + } + + public override string ToString() => GetDisplayName(); + + public IConnectionBuilder Use(Func middleware) + { + _middleware.Add(middleware); + return this; + } + + public ConnectionDelegate Build() + { + ConnectionDelegate app = context => + { + return Task.CompletedTask; + }; + + for (int i = _middleware.Count - 1; i >= 0; i--) + { + var component = _middleware[i]; + app = component(app); + } + + return app; + } + + internal virtual async Task BindAsync(AddressBindContext context) + { + await AddressBinder.BindEndpointAsync(this, context).ConfigureAwait(false); + context.Addresses.Add(GetDisplayName()); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs new file mode 100644 index 0000000000..80c6eda3d2 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs @@ -0,0 +1,220 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Security.Cryptography.X509Certificates; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Hosting +{ + /// + /// Extension methods for that configure Kestrel to use HTTPS for a given endpoint. + /// + public static class ListenOptionsHttpsExtensions + { + /// + /// Configure Kestrel to use HTTPS with the default certificate if available. + /// This will throw if no default certificate is configured. + /// + /// The to configure. + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions) => listenOptions.UseHttps(_ => { }); + + /// + /// Configure Kestrel to use HTTPS. + /// + /// The to configure. + /// The name of a certificate file, relative to the directory that contains the application + /// content files. + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, string fileName) + { + var env = listenOptions.KestrelServerOptions.ApplicationServices.GetRequiredService(); + return listenOptions.UseHttps(new X509Certificate2(Path.Combine(env.ContentRootPath, fileName))); + } + + /// + /// Configure Kestrel to use HTTPS. + /// + /// The to configure. + /// The name of a certificate file, relative to the directory that contains the application + /// content files. + /// The password required to access the X.509 certificate data. + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, string fileName, string password) + { + var env = listenOptions.KestrelServerOptions.ApplicationServices.GetRequiredService(); + return listenOptions.UseHttps(new X509Certificate2(Path.Combine(env.ContentRootPath, fileName), password)); + } + + /// + /// Configure Kestrel to use HTTPS. + /// + /// The to configure. + /// The name of a certificate file, relative to the directory that contains the application content files. + /// The password required to access the X.509 certificate data. + /// An Action to configure the . + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, string fileName, string password, + Action configureOptions) + { + var env = listenOptions.KestrelServerOptions.ApplicationServices.GetRequiredService(); + return listenOptions.UseHttps(new X509Certificate2(Path.Combine(env.ContentRootPath, fileName), password), configureOptions); + } + + /// + /// Configure Kestrel to use HTTPS. + /// + /// The to configure. + /// The certificate store to load the certificate from. + /// The subject name for the certificate to load. + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, StoreName storeName, string subject) + => listenOptions.UseHttps(storeName, subject, allowInvalid: false); + + /// + /// Configure Kestrel to use HTTPS. + /// + /// The to configure. + /// The certificate store to load the certificate from. + /// The subject name for the certificate to load. + /// Indicates if invalid certificates should be considered, such as self-signed certificates. + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, StoreName storeName, string subject, bool allowInvalid) + => listenOptions.UseHttps(storeName, subject, allowInvalid, StoreLocation.CurrentUser); + + /// + /// Configure Kestrel to use HTTPS. + /// + /// The to configure. + /// The certificate store to load the certificate from. + /// The subject name for the certificate to load. + /// Indicates if invalid certificates should be considered, such as self-signed certificates. + /// The store location to load the certificate from. + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, StoreName storeName, string subject, bool allowInvalid, StoreLocation location) + => listenOptions.UseHttps(storeName, subject, allowInvalid, location, configureOptions: _ => { }); + + /// + /// Configure Kestrel to use HTTPS. + /// + /// The to configure. + /// The certificate store to load the certificate from. + /// The subject name for the certificate to load. + /// Indicates if invalid certificates should be considered, such as self-signed certificates. + /// The store location to load the certificate from. + /// An Action to configure the . + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, StoreName storeName, string subject, bool allowInvalid, StoreLocation location, + Action configureOptions) + { + return listenOptions.UseHttps(CertificateLoader.LoadFromStoreCert(subject, storeName.ToString(), location, allowInvalid), configureOptions); + } + + /// + /// Configure Kestrel to use HTTPS. + /// + /// The to configure. + /// The X.509 certificate. + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, X509Certificate2 serverCertificate) + { + if (serverCertificate == null) + { + throw new ArgumentNullException(nameof(serverCertificate)); + } + + return listenOptions.UseHttps(options => + { + options.ServerCertificate = serverCertificate; + }); + } + + /// + /// Configure Kestrel to use HTTPS. + /// + /// The to configure. + /// The X.509 certificate. + /// An Action to configure the . + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, X509Certificate2 serverCertificate, + Action configureOptions) + { + if (serverCertificate == null) + { + throw new ArgumentNullException(nameof(serverCertificate)); + } + + if (configureOptions == null) + { + throw new ArgumentNullException(nameof(configureOptions)); + } + + return listenOptions.UseHttps(options => + { + options.ServerCertificate = serverCertificate; + configureOptions(options); + }); + } + + /// + /// Configure Kestrel to use HTTPS. + /// + /// The to configure. + /// An action to configure options for HTTPS. + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, Action configureOptions) + { + if (configureOptions == null) + { + throw new ArgumentNullException(nameof(configureOptions)); + } + + var options = new HttpsConnectionAdapterOptions(); + listenOptions.KestrelServerOptions.ApplyHttpsDefaults(options); + configureOptions(options); + listenOptions.KestrelServerOptions.ApplyDefaultCert(options); + + if (options.ServerCertificate == null && options.ServerCertificateSelector == null) + { + throw new InvalidOperationException(CoreStrings.NoCertSpecifiedNoDevelopmentCertificateFound); + } + return listenOptions.UseHttps(options); + } + + // Use Https if a default cert is available + internal static bool TryUseHttps(this ListenOptions listenOptions) + { + var options = new HttpsConnectionAdapterOptions(); + listenOptions.KestrelServerOptions.ApplyHttpsDefaults(options); + listenOptions.KestrelServerOptions.ApplyDefaultCert(options); + + if (options.ServerCertificate == null && options.ServerCertificateSelector == null) + { + return false; + } + listenOptions.UseHttps(options); + return true; + } + + /// + /// Configure Kestrel to use HTTPS. + /// + /// The to configure. + /// Options to configure HTTPS. + /// The . + public static ListenOptions UseHttps(this ListenOptions listenOptions, HttpsConnectionAdapterOptions httpsOptions) + { + var loggerFactory = listenOptions.KestrelServerOptions.ApplicationServices.GetRequiredService(); + // Set the list of protocols from listen options + httpsOptions.HttpProtocols = listenOptions.Protocols; + listenOptions.ConnectionAdapters.Add(new HttpsConnectionAdapter(httpsOptions, loggerFactory)); + return listenOptions; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs b/src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs new file mode 100644 index 0000000000..ee28dce63e --- /dev/null +++ b/src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs @@ -0,0 +1,90 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core +{ + internal class LocalhostListenOptions : ListenOptions + { + internal LocalhostListenOptions(int port) + : base(new IPEndPoint(IPAddress.Loopback, port)) + { + if (port == 0) + { + throw new InvalidOperationException(CoreStrings.DynamicPortOnLocalhostNotSupported); + } + } + + /// + /// Gets the name of this endpoint to display on command-line when the web server starts. + /// + internal override string GetDisplayName() + { + var scheme = ConnectionAdapters.Any(f => f.IsHttps) + ? "https" + : "http"; + + return $"{scheme}://localhost:{IPEndPoint.Port}"; + } + + internal override async Task BindAsync(AddressBindContext context) + { + var exceptions = new List(); + + try + { + var v4Options = Clone(IPAddress.Loopback); + await AddressBinder.BindEndpointAsync(v4Options, context).ConfigureAwait(false); + } + catch (Exception ex) when (!(ex is IOException)) + { + context.Logger.LogWarning(0, CoreStrings.NetworkInterfaceBindingFailed, GetDisplayName(), "IPv4 loopback", ex.Message); + exceptions.Add(ex); + } + + try + { + var v6Options = Clone(IPAddress.IPv6Loopback); + await AddressBinder.BindEndpointAsync(v6Options, context).ConfigureAwait(false); + } + catch (Exception ex) when (!(ex is IOException)) + { + context.Logger.LogWarning(0, CoreStrings.NetworkInterfaceBindingFailed, GetDisplayName(), "IPv6 loopback", ex.Message); + exceptions.Add(ex); + } + + if (exceptions.Count == 2) + { + throw new IOException(CoreStrings.FormatAddressBindingFailed(GetDisplayName()), new AggregateException(exceptions)); + } + + // If StartLocalhost doesn't throw, there is at least one listener. + // The port cannot change for "localhost". + context.Addresses.Add(GetDisplayName()); + } + + // used for cloning to two IPEndpoints + internal ListenOptions Clone(IPAddress address) + { + var options = new ListenOptions(new IPEndPoint(address, IPEndPoint.Port)) + { + HandleType = HandleType, + KestrelServerOptions = KestrelServerOptions, + NoDelay = NoDelay, + Protocols = Protocols, + }; + + options._middleware.AddRange(_middleware); + options.ConnectionAdapters.AddRange(ConnectionAdapters); + return options; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Microsoft.AspNetCore.Server.Kestrel.Core.csproj b/src/Servers/Kestrel/Core/src/Microsoft.AspNetCore.Server.Kestrel.Core.csproj new file mode 100644 index 0000000000..dae8056002 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Microsoft.AspNetCore.Server.Kestrel.Core.csproj @@ -0,0 +1,38 @@ + + + + Core components of ASP.NET Core Kestrel cross-platform web server. + netstandard2.0;netcoreapp2.1 + true + aspnetcore;kestrel + true + CS1591;$(NoWarn) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/Servers/Kestrel/Core/src/MinDataRate.cs b/src/Servers/Kestrel/Core/src/MinDataRate.cs new file mode 100644 index 0000000000..0e320b37f1 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/MinDataRate.cs @@ -0,0 +1,44 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core +{ + public class MinDataRate + { + /// + /// Creates a new instance of . + /// + /// The minimum rate in bytes/second at which data should be processed. + /// The amount of time to delay enforcement of , + /// starting at the time data is first read or written. + public MinDataRate(double bytesPerSecond, TimeSpan gracePeriod) + { + if (bytesPerSecond <= 0) + { + throw new ArgumentOutOfRangeException(nameof(bytesPerSecond), CoreStrings.PositiveNumberOrNullMinDataRateRequired); + } + + if (gracePeriod <= Heartbeat.Interval) + { + throw new ArgumentOutOfRangeException(nameof(gracePeriod), CoreStrings.FormatMinimumGracePeriodRequired(Heartbeat.Interval.TotalSeconds)); + } + + BytesPerSecond = bytesPerSecond; + GracePeriod = gracePeriod; + } + + /// + /// The minimum rate in bytes/second at which data should be processed. + /// + public double BytesPerSecond { get; } + + /// + /// The amount of time to delay enforcement of , + /// starting at the time data is first read or written. + /// + public TimeSpan GracePeriod { get; } + } +} diff --git a/src/Servers/Kestrel/Core/src/Properties/AssemblyInfo.cs b/src/Servers/Kestrel/Core/src/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..1806c94cf2 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Properties/AssemblyInfo.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.Server.Kestrel.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] +[assembly: InternalsVisibleTo("Libuv.FunctionalTests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] +[assembly: InternalsVisibleTo("Sockets.FunctionalTests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.Server.Kestrel.Core.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.Server.Kestrel.Performance, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] +[assembly: InternalsVisibleTo("Http2SampleApp, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] +[assembly: InternalsVisibleTo("PlatformBenchmarks, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] diff --git a/src/Servers/Kestrel/Core/src/Properties/CoreStrings.Designer.cs b/src/Servers/Kestrel/Core/src/Properties/CoreStrings.Designer.cs new file mode 100644 index 0000000000..c813873491 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Properties/CoreStrings.Designer.cs @@ -0,0 +1,1896 @@ +// +namespace Microsoft.AspNetCore.Server.Kestrel.Core +{ + using System.Globalization; + using System.Reflection; + using System.Resources; + + internal static class CoreStrings + { + private static readonly ResourceManager _resourceManager + = new ResourceManager("Microsoft.AspNetCore.Server.Kestrel.Core.CoreStrings", typeof(CoreStrings).GetTypeInfo().Assembly); + + /// + /// Bad request. + /// + internal static string BadRequest + { + get => GetString("BadRequest"); + } + + /// + /// Bad request. + /// + internal static string FormatBadRequest() + => GetString("BadRequest"); + + /// + /// Bad chunk size data. + /// + internal static string BadRequest_BadChunkSizeData + { + get => GetString("BadRequest_BadChunkSizeData"); + } + + /// + /// Bad chunk size data. + /// + internal static string FormatBadRequest_BadChunkSizeData() + => GetString("BadRequest_BadChunkSizeData"); + + /// + /// Bad chunk suffix. + /// + internal static string BadRequest_BadChunkSuffix + { + get => GetString("BadRequest_BadChunkSuffix"); + } + + /// + /// Bad chunk suffix. + /// + internal static string FormatBadRequest_BadChunkSuffix() + => GetString("BadRequest_BadChunkSuffix"); + + /// + /// Chunked request incomplete. + /// + internal static string BadRequest_ChunkedRequestIncomplete + { + get => GetString("BadRequest_ChunkedRequestIncomplete"); + } + + /// + /// Chunked request incomplete. + /// + internal static string FormatBadRequest_ChunkedRequestIncomplete() + => GetString("BadRequest_ChunkedRequestIncomplete"); + + /// + /// The message body length cannot be determined because the final transfer coding was set to '{detail}' instead of 'chunked'. + /// + internal static string BadRequest_FinalTransferCodingNotChunked + { + get => GetString("BadRequest_FinalTransferCodingNotChunked"); + } + + /// + /// The message body length cannot be determined because the final transfer coding was set to '{detail}' instead of 'chunked'. + /// + internal static string FormatBadRequest_FinalTransferCodingNotChunked(object detail) + => string.Format(CultureInfo.CurrentCulture, GetString("BadRequest_FinalTransferCodingNotChunked", "detail"), detail); + + /// + /// Request headers too long. + /// + internal static string BadRequest_HeadersExceedMaxTotalSize + { + get => GetString("BadRequest_HeadersExceedMaxTotalSize"); + } + + /// + /// Request headers too long. + /// + internal static string FormatBadRequest_HeadersExceedMaxTotalSize() + => GetString("BadRequest_HeadersExceedMaxTotalSize"); + + /// + /// Invalid characters in header name. + /// + internal static string BadRequest_InvalidCharactersInHeaderName + { + get => GetString("BadRequest_InvalidCharactersInHeaderName"); + } + + /// + /// Invalid characters in header name. + /// + internal static string FormatBadRequest_InvalidCharactersInHeaderName() + => GetString("BadRequest_InvalidCharactersInHeaderName"); + + /// + /// Invalid content length: {detail} + /// + internal static string BadRequest_InvalidContentLength_Detail + { + get => GetString("BadRequest_InvalidContentLength_Detail"); + } + + /// + /// Invalid content length: {detail} + /// + internal static string FormatBadRequest_InvalidContentLength_Detail(object detail) + => string.Format(CultureInfo.CurrentCulture, GetString("BadRequest_InvalidContentLength_Detail", "detail"), detail); + + /// + /// Invalid Host header. + /// + internal static string BadRequest_InvalidHostHeader + { + get => GetString("BadRequest_InvalidHostHeader"); + } + + /// + /// Invalid Host header. + /// + internal static string FormatBadRequest_InvalidHostHeader() + => GetString("BadRequest_InvalidHostHeader"); + + /// + /// Invalid Host header: '{detail}' + /// + internal static string BadRequest_InvalidHostHeader_Detail + { + get => GetString("BadRequest_InvalidHostHeader_Detail"); + } + + /// + /// Invalid Host header: '{detail}' + /// + internal static string FormatBadRequest_InvalidHostHeader_Detail(object detail) + => string.Format(CultureInfo.CurrentCulture, GetString("BadRequest_InvalidHostHeader_Detail", "detail"), detail); + + /// + /// Invalid request headers: missing final CRLF in header fields. + /// + internal static string BadRequest_InvalidRequestHeadersNoCRLF + { + get => GetString("BadRequest_InvalidRequestHeadersNoCRLF"); + } + + /// + /// Invalid request headers: missing final CRLF in header fields. + /// + internal static string FormatBadRequest_InvalidRequestHeadersNoCRLF() + => GetString("BadRequest_InvalidRequestHeadersNoCRLF"); + + /// + /// Invalid request header: '{detail}' + /// + internal static string BadRequest_InvalidRequestHeader_Detail + { + get => GetString("BadRequest_InvalidRequestHeader_Detail"); + } + + /// + /// Invalid request header: '{detail}' + /// + internal static string FormatBadRequest_InvalidRequestHeader_Detail(object detail) + => string.Format(CultureInfo.CurrentCulture, GetString("BadRequest_InvalidRequestHeader_Detail", "detail"), detail); + + /// + /// Invalid request line. + /// + internal static string BadRequest_InvalidRequestLine + { + get => GetString("BadRequest_InvalidRequestLine"); + } + + /// + /// Invalid request line. + /// + internal static string FormatBadRequest_InvalidRequestLine() + => GetString("BadRequest_InvalidRequestLine"); + + /// + /// Invalid request line: '{detail}' + /// + internal static string BadRequest_InvalidRequestLine_Detail + { + get => GetString("BadRequest_InvalidRequestLine_Detail"); + } + + /// + /// Invalid request line: '{detail}' + /// + internal static string FormatBadRequest_InvalidRequestLine_Detail(object detail) + => string.Format(CultureInfo.CurrentCulture, GetString("BadRequest_InvalidRequestLine_Detail", "detail"), detail); + + /// + /// Invalid request target: '{detail}' + /// + internal static string BadRequest_InvalidRequestTarget_Detail + { + get => GetString("BadRequest_InvalidRequestTarget_Detail"); + } + + /// + /// Invalid request target: '{detail}' + /// + internal static string FormatBadRequest_InvalidRequestTarget_Detail(object detail) + => string.Format(CultureInfo.CurrentCulture, GetString("BadRequest_InvalidRequestTarget_Detail", "detail"), detail); + + /// + /// {detail} request contains no Content-Length or Transfer-Encoding header. + /// + internal static string BadRequest_LengthRequired + { + get => GetString("BadRequest_LengthRequired"); + } + + /// + /// {detail} request contains no Content-Length or Transfer-Encoding header. + /// + internal static string FormatBadRequest_LengthRequired(object detail) + => string.Format(CultureInfo.CurrentCulture, GetString("BadRequest_LengthRequired", "detail"), detail); + + /// + /// {detail} request contains no Content-Length header. + /// + internal static string BadRequest_LengthRequiredHttp10 + { + get => GetString("BadRequest_LengthRequiredHttp10"); + } + + /// + /// {detail} request contains no Content-Length header. + /// + internal static string FormatBadRequest_LengthRequiredHttp10(object detail) + => string.Format(CultureInfo.CurrentCulture, GetString("BadRequest_LengthRequiredHttp10", "detail"), detail); + + /// + /// Malformed request: invalid headers. + /// + internal static string BadRequest_MalformedRequestInvalidHeaders + { + get => GetString("BadRequest_MalformedRequestInvalidHeaders"); + } + + /// + /// Malformed request: invalid headers. + /// + internal static string FormatBadRequest_MalformedRequestInvalidHeaders() + => GetString("BadRequest_MalformedRequestInvalidHeaders"); + + /// + /// Method not allowed. + /// + internal static string BadRequest_MethodNotAllowed + { + get => GetString("BadRequest_MethodNotAllowed"); + } + + /// + /// Method not allowed. + /// + internal static string FormatBadRequest_MethodNotAllowed() + => GetString("BadRequest_MethodNotAllowed"); + + /// + /// Request is missing Host header. + /// + internal static string BadRequest_MissingHostHeader + { + get => GetString("BadRequest_MissingHostHeader"); + } + + /// + /// Request is missing Host header. + /// + internal static string FormatBadRequest_MissingHostHeader() + => GetString("BadRequest_MissingHostHeader"); + + /// + /// Multiple Content-Length headers. + /// + internal static string BadRequest_MultipleContentLengths + { + get => GetString("BadRequest_MultipleContentLengths"); + } + + /// + /// Multiple Content-Length headers. + /// + internal static string FormatBadRequest_MultipleContentLengths() + => GetString("BadRequest_MultipleContentLengths"); + + /// + /// Multiple Host headers. + /// + internal static string BadRequest_MultipleHostHeaders + { + get => GetString("BadRequest_MultipleHostHeaders"); + } + + /// + /// Multiple Host headers. + /// + internal static string FormatBadRequest_MultipleHostHeaders() + => GetString("BadRequest_MultipleHostHeaders"); + + /// + /// Request line too long. + /// + internal static string BadRequest_RequestLineTooLong + { + get => GetString("BadRequest_RequestLineTooLong"); + } + + /// + /// Request line too long. + /// + internal static string FormatBadRequest_RequestLineTooLong() + => GetString("BadRequest_RequestLineTooLong"); + + /// + /// Reading the request headers timed out. + /// + internal static string BadRequest_RequestHeadersTimeout + { + get => GetString("BadRequest_RequestHeadersTimeout"); + } + + /// + /// Reading the request headers timed out. + /// + internal static string FormatBadRequest_RequestHeadersTimeout() + => GetString("BadRequest_RequestHeadersTimeout"); + + /// + /// Request contains too many headers. + /// + internal static string BadRequest_TooManyHeaders + { + get => GetString("BadRequest_TooManyHeaders"); + } + + /// + /// Request contains too many headers. + /// + internal static string FormatBadRequest_TooManyHeaders() + => GetString("BadRequest_TooManyHeaders"); + + /// + /// Unexpected end of request content. + /// + internal static string BadRequest_UnexpectedEndOfRequestContent + { + get => GetString("BadRequest_UnexpectedEndOfRequestContent"); + } + + /// + /// Unexpected end of request content. + /// + internal static string FormatBadRequest_UnexpectedEndOfRequestContent() + => GetString("BadRequest_UnexpectedEndOfRequestContent"); + + /// + /// Unrecognized HTTP version: '{detail}' + /// + internal static string BadRequest_UnrecognizedHTTPVersion + { + get => GetString("BadRequest_UnrecognizedHTTPVersion"); + } + + /// + /// Unrecognized HTTP version: '{detail}' + /// + internal static string FormatBadRequest_UnrecognizedHTTPVersion(object detail) + => string.Format(CultureInfo.CurrentCulture, GetString("BadRequest_UnrecognizedHTTPVersion", "detail"), detail); + + /// + /// Requests with 'Connection: Upgrade' cannot have content in the request body. + /// + internal static string BadRequest_UpgradeRequestCannotHavePayload + { + get => GetString("BadRequest_UpgradeRequestCannotHavePayload"); + } + + /// + /// Requests with 'Connection: Upgrade' cannot have content in the request body. + /// + internal static string FormatBadRequest_UpgradeRequestCannotHavePayload() + => GetString("BadRequest_UpgradeRequestCannotHavePayload"); + + /// + /// Failed to bind to http://[::]:{port} (IPv6Any). Attempting to bind to http://0.0.0.0:{port} instead. + /// + internal static string FallbackToIPv4Any + { + get => GetString("FallbackToIPv4Any"); + } + + /// + /// Failed to bind to http://[::]:{port} (IPv6Any). Attempting to bind to http://0.0.0.0:{port} instead. + /// + internal static string FormatFallbackToIPv4Any(object port) + => string.Format(CultureInfo.CurrentCulture, GetString("FallbackToIPv4Any", "port"), port); + + /// + /// Cannot write to response body after connection has been upgraded. + /// + internal static string ResponseStreamWasUpgraded + { + get => GetString("ResponseStreamWasUpgraded"); + } + + /// + /// Cannot write to response body after connection has been upgraded. + /// + internal static string FormatResponseStreamWasUpgraded() + => GetString("ResponseStreamWasUpgraded"); + + /// + /// Kestrel does not support big-endian architectures. + /// + internal static string BigEndianNotSupported + { + get => GetString("BigEndianNotSupported"); + } + + /// + /// Kestrel does not support big-endian architectures. + /// + internal static string FormatBigEndianNotSupported() + => GetString("BigEndianNotSupported"); + + /// + /// Maximum request buffer size ({requestBufferSize}) must be greater than or equal to maximum request header size ({requestHeaderSize}). + /// + internal static string MaxRequestBufferSmallerThanRequestHeaderBuffer + { + get => GetString("MaxRequestBufferSmallerThanRequestHeaderBuffer"); + } + + /// + /// Maximum request buffer size ({requestBufferSize}) must be greater than or equal to maximum request header size ({requestHeaderSize}). + /// + internal static string FormatMaxRequestBufferSmallerThanRequestHeaderBuffer(object requestBufferSize, object requestHeaderSize) + => string.Format(CultureInfo.CurrentCulture, GetString("MaxRequestBufferSmallerThanRequestHeaderBuffer", "requestBufferSize", "requestHeaderSize"), requestBufferSize, requestHeaderSize); + + /// + /// Maximum request buffer size ({requestBufferSize}) must be greater than or equal to maximum request line size ({requestLineSize}). + /// + internal static string MaxRequestBufferSmallerThanRequestLineBuffer + { + get => GetString("MaxRequestBufferSmallerThanRequestLineBuffer"); + } + + /// + /// Maximum request buffer size ({requestBufferSize}) must be greater than or equal to maximum request line size ({requestLineSize}). + /// + internal static string FormatMaxRequestBufferSmallerThanRequestLineBuffer(object requestBufferSize, object requestLineSize) + => string.Format(CultureInfo.CurrentCulture, GetString("MaxRequestBufferSmallerThanRequestLineBuffer", "requestBufferSize", "requestLineSize"), requestBufferSize, requestLineSize); + + /// + /// Server has already started. + /// + internal static string ServerAlreadyStarted + { + get => GetString("ServerAlreadyStarted"); + } + + /// + /// Server has already started. + /// + internal static string FormatServerAlreadyStarted() + => GetString("ServerAlreadyStarted"); + + /// + /// Unknown transport mode: '{mode}'. + /// + internal static string UnknownTransportMode + { + get => GetString("UnknownTransportMode"); + } + + /// + /// Unknown transport mode: '{mode}'. + /// + internal static string FormatUnknownTransportMode(object mode) + => string.Format(CultureInfo.CurrentCulture, GetString("UnknownTransportMode", "mode"), mode); + + /// + /// Invalid non-ASCII or control character in header: {character} + /// + internal static string InvalidAsciiOrControlChar + { + get => GetString("InvalidAsciiOrControlChar"); + } + + /// + /// Invalid non-ASCII or control character in header: {character} + /// + internal static string FormatInvalidAsciiOrControlChar(object character) + => string.Format(CultureInfo.CurrentCulture, GetString("InvalidAsciiOrControlChar", "character"), character); + + /// + /// Invalid Content-Length: "{value}". Value must be a positive integral number. + /// + internal static string InvalidContentLength_InvalidNumber + { + get => GetString("InvalidContentLength_InvalidNumber"); + } + + /// + /// Invalid Content-Length: "{value}". Value must be a positive integral number. + /// + internal static string FormatInvalidContentLength_InvalidNumber(object value) + => string.Format(CultureInfo.CurrentCulture, GetString("InvalidContentLength_InvalidNumber", "value"), value); + + /// + /// Value must be null or a non-negative number. + /// + internal static string NonNegativeNumberOrNullRequired + { + get => GetString("NonNegativeNumberOrNullRequired"); + } + + /// + /// Value must be null or a non-negative number. + /// + internal static string FormatNonNegativeNumberOrNullRequired() + => GetString("NonNegativeNumberOrNullRequired"); + + /// + /// Value must be a non-negative number. + /// + internal static string NonNegativeNumberRequired + { + get => GetString("NonNegativeNumberRequired"); + } + + /// + /// Value must be a non-negative number. + /// + internal static string FormatNonNegativeNumberRequired() + => GetString("NonNegativeNumberRequired"); + + /// + /// Value must be a positive number. + /// + internal static string PositiveNumberRequired + { + get => GetString("PositiveNumberRequired"); + } + + /// + /// Value must be a positive number. + /// + internal static string FormatPositiveNumberRequired() + => GetString("PositiveNumberRequired"); + + /// + /// Value must be null or a positive number. + /// + internal static string PositiveNumberOrNullRequired + { + get => GetString("PositiveNumberOrNullRequired"); + } + + /// + /// Value must be null or a positive number. + /// + internal static string FormatPositiveNumberOrNullRequired() + => GetString("PositiveNumberOrNullRequired"); + + /// + /// Unix socket path must be absolute. + /// + internal static string UnixSocketPathMustBeAbsolute + { + get => GetString("UnixSocketPathMustBeAbsolute"); + } + + /// + /// Unix socket path must be absolute. + /// + internal static string FormatUnixSocketPathMustBeAbsolute() + => GetString("UnixSocketPathMustBeAbsolute"); + + /// + /// Failed to bind to address {address}. + /// + internal static string AddressBindingFailed + { + get => GetString("AddressBindingFailed"); + } + + /// + /// Failed to bind to address {address}. + /// + internal static string FormatAddressBindingFailed(object address) + => string.Format(CultureInfo.CurrentCulture, GetString("AddressBindingFailed", "address"), address); + + /// + /// No listening endpoints were configured. Binding to {address} by default. + /// + internal static string BindingToDefaultAddress + { + get => GetString("BindingToDefaultAddress"); + } + + /// + /// No listening endpoints were configured. Binding to {address} by default. + /// + internal static string FormatBindingToDefaultAddress(object address) + => string.Format(CultureInfo.CurrentCulture, GetString("BindingToDefaultAddress", "address"), address); + + /// + /// HTTPS endpoints can only be configured using {methodName}. + /// + internal static string ConfigureHttpsFromMethodCall + { + get => GetString("ConfigureHttpsFromMethodCall"); + } + + /// + /// HTTPS endpoints can only be configured using {methodName}. + /// + internal static string FormatConfigureHttpsFromMethodCall(object methodName) + => string.Format(CultureInfo.CurrentCulture, GetString("ConfigureHttpsFromMethodCall", "methodName"), methodName); + + /// + /// A path base can only be configured using {methodName}. + /// + internal static string ConfigurePathBaseFromMethodCall + { + get => GetString("ConfigurePathBaseFromMethodCall"); + } + + /// + /// A path base can only be configured using {methodName}. + /// + internal static string FormatConfigurePathBaseFromMethodCall(object methodName) + => string.Format(CultureInfo.CurrentCulture, GetString("ConfigurePathBaseFromMethodCall", "methodName"), methodName); + + /// + /// Dynamic port binding is not supported when binding to localhost. You must either bind to 127.0.0.1:0 or [::1]:0, or both. + /// + internal static string DynamicPortOnLocalhostNotSupported + { + get => GetString("DynamicPortOnLocalhostNotSupported"); + } + + /// + /// Dynamic port binding is not supported when binding to localhost. You must either bind to 127.0.0.1:0 or [::1]:0, or both. + /// + internal static string FormatDynamicPortOnLocalhostNotSupported() + => GetString("DynamicPortOnLocalhostNotSupported"); + + /// + /// Failed to bind to address {endpoint}: address already in use. + /// + internal static string EndpointAlreadyInUse + { + get => GetString("EndpointAlreadyInUse"); + } + + /// + /// Failed to bind to address {endpoint}: address already in use. + /// + internal static string FormatEndpointAlreadyInUse(object endpoint) + => string.Format(CultureInfo.CurrentCulture, GetString("EndpointAlreadyInUse", "endpoint"), endpoint); + + /// + /// Invalid URL: '{url}'. + /// + internal static string InvalidUrl + { + get => GetString("InvalidUrl"); + } + + /// + /// Invalid URL: '{url}'. + /// + internal static string FormatInvalidUrl(object url) + => string.Format(CultureInfo.CurrentCulture, GetString("InvalidUrl", "url"), url); + + /// + /// Unable to bind to {address} on the {interfaceName} interface: '{error}'. + /// + internal static string NetworkInterfaceBindingFailed + { + get => GetString("NetworkInterfaceBindingFailed"); + } + + /// + /// Unable to bind to {address} on the {interfaceName} interface: '{error}'. + /// + internal static string FormatNetworkInterfaceBindingFailed(object address, object interfaceName, object error) + => string.Format(CultureInfo.CurrentCulture, GetString("NetworkInterfaceBindingFailed", "address", "interfaceName", "error"), address, interfaceName, error); + + /// + /// Overriding address(es) '{addresses}'. Binding to endpoints defined in {methodName} instead. + /// + internal static string OverridingWithKestrelOptions + { + get => GetString("OverridingWithKestrelOptions"); + } + + /// + /// Overriding address(es) '{addresses}'. Binding to endpoints defined in {methodName} instead. + /// + internal static string FormatOverridingWithKestrelOptions(object addresses, object methodName) + => string.Format(CultureInfo.CurrentCulture, GetString("OverridingWithKestrelOptions", "addresses", "methodName"), addresses, methodName); + + /// + /// Overriding endpoints defined in UseKestrel() because {settingName} is set to true. Binding to address(es) '{addresses}' instead. + /// + internal static string OverridingWithPreferHostingUrls + { + get => GetString("OverridingWithPreferHostingUrls"); + } + + /// + /// Overriding endpoints defined in UseKestrel() because {settingName} is set to true. Binding to address(es) '{addresses}' instead. + /// + internal static string FormatOverridingWithPreferHostingUrls(object settingName, object addresses) + => string.Format(CultureInfo.CurrentCulture, GetString("OverridingWithPreferHostingUrls", "settingName", "addresses"), settingName, addresses); + + /// + /// Unrecognized scheme in server address '{address}'. Only 'http://' is supported. + /// + internal static string UnsupportedAddressScheme + { + get => GetString("UnsupportedAddressScheme"); + } + + /// + /// Unrecognized scheme in server address '{address}'. Only 'http://' is supported. + /// + internal static string FormatUnsupportedAddressScheme(object address) + => string.Format(CultureInfo.CurrentCulture, GetString("UnsupportedAddressScheme", "address"), address); + + /// + /// Headers are read-only, response has already started. + /// + internal static string HeadersAreReadOnly + { + get => GetString("HeadersAreReadOnly"); + } + + /// + /// Headers are read-only, response has already started. + /// + internal static string FormatHeadersAreReadOnly() + => GetString("HeadersAreReadOnly"); + + /// + /// An item with the same key has already been added. + /// + internal static string KeyAlreadyExists + { + get => GetString("KeyAlreadyExists"); + } + + /// + /// An item with the same key has already been added. + /// + internal static string FormatKeyAlreadyExists() + => GetString("KeyAlreadyExists"); + + /// + /// Setting the header {name} is not allowed on responses with status code {statusCode}. + /// + internal static string HeaderNotAllowedOnResponse + { + get => GetString("HeaderNotAllowedOnResponse"); + } + + /// + /// Setting the header {name} is not allowed on responses with status code {statusCode}. + /// + internal static string FormatHeaderNotAllowedOnResponse(object name, object statusCode) + => string.Format(CultureInfo.CurrentCulture, GetString("HeaderNotAllowedOnResponse", "name", "statusCode"), name, statusCode); + + /// + /// {name} cannot be set because the response has already started. + /// + internal static string ParameterReadOnlyAfterResponseStarted + { + get => GetString("ParameterReadOnlyAfterResponseStarted"); + } + + /// + /// {name} cannot be set because the response has already started. + /// + internal static string FormatParameterReadOnlyAfterResponseStarted(object name) + => string.Format(CultureInfo.CurrentCulture, GetString("ParameterReadOnlyAfterResponseStarted", "name"), name); + + /// + /// Request processing didn't complete within the shutdown timeout. + /// + internal static string RequestProcessingAborted + { + get => GetString("RequestProcessingAborted"); + } + + /// + /// Request processing didn't complete within the shutdown timeout. + /// + internal static string FormatRequestProcessingAborted() + => GetString("RequestProcessingAborted"); + + /// + /// Response Content-Length mismatch: too few bytes written ({written} of {expected}). + /// + internal static string TooFewBytesWritten + { + get => GetString("TooFewBytesWritten"); + } + + /// + /// Response Content-Length mismatch: too few bytes written ({written} of {expected}). + /// + internal static string FormatTooFewBytesWritten(object written, object expected) + => string.Format(CultureInfo.CurrentCulture, GetString("TooFewBytesWritten", "written", "expected"), written, expected); + + /// + /// Response Content-Length mismatch: too many bytes written ({written} of {expected}). + /// + internal static string TooManyBytesWritten + { + get => GetString("TooManyBytesWritten"); + } + + /// + /// Response Content-Length mismatch: too many bytes written ({written} of {expected}). + /// + internal static string FormatTooManyBytesWritten(object written, object expected) + => string.Format(CultureInfo.CurrentCulture, GetString("TooManyBytesWritten", "written", "expected"), written, expected); + + /// + /// The response has been aborted due to an unhandled application exception. + /// + internal static string UnhandledApplicationException + { + get => GetString("UnhandledApplicationException"); + } + + /// + /// The response has been aborted due to an unhandled application exception. + /// + internal static string FormatUnhandledApplicationException() + => GetString("UnhandledApplicationException"); + + /// + /// Writing to the response body is invalid for responses with status code {statusCode}. + /// + internal static string WritingToResponseBodyNotSupported + { + get => GetString("WritingToResponseBodyNotSupported"); + } + + /// + /// Writing to the response body is invalid for responses with status code {statusCode}. + /// + internal static string FormatWritingToResponseBodyNotSupported(object statusCode) + => string.Format(CultureInfo.CurrentCulture, GetString("WritingToResponseBodyNotSupported", "statusCode"), statusCode); + + /// + /// Connection shutdown abnormally. + /// + internal static string ConnectionShutdownError + { + get => GetString("ConnectionShutdownError"); + } + + /// + /// Connection shutdown abnormally. + /// + internal static string FormatConnectionShutdownError() + => GetString("ConnectionShutdownError"); + + /// + /// Connection processing ended abnormally. + /// + internal static string RequestProcessingEndError + { + get => GetString("RequestProcessingEndError"); + } + + /// + /// Connection processing ended abnormally. + /// + internal static string FormatRequestProcessingEndError() + => GetString("RequestProcessingEndError"); + + /// + /// Cannot upgrade a non-upgradable request. Check IHttpUpgradeFeature.IsUpgradableRequest to determine if a request can be upgraded. + /// + internal static string CannotUpgradeNonUpgradableRequest + { + get => GetString("CannotUpgradeNonUpgradableRequest"); + } + + /// + /// Cannot upgrade a non-upgradable request. Check IHttpUpgradeFeature.IsUpgradableRequest to determine if a request can be upgraded. + /// + internal static string FormatCannotUpgradeNonUpgradableRequest() + => GetString("CannotUpgradeNonUpgradableRequest"); + + /// + /// Request cannot be upgraded because the server has already opened the maximum number of upgraded connections. + /// + internal static string UpgradedConnectionLimitReached + { + get => GetString("UpgradedConnectionLimitReached"); + } + + /// + /// Request cannot be upgraded because the server has already opened the maximum number of upgraded connections. + /// + internal static string FormatUpgradedConnectionLimitReached() + => GetString("UpgradedConnectionLimitReached"); + + /// + /// IHttpUpgradeFeature.UpgradeAsync was already called and can only be called once per connection. + /// + internal static string UpgradeCannotBeCalledMultipleTimes + { + get => GetString("UpgradeCannotBeCalledMultipleTimes"); + } + + /// + /// IHttpUpgradeFeature.UpgradeAsync was already called and can only be called once per connection. + /// + internal static string FormatUpgradeCannotBeCalledMultipleTimes() + => GetString("UpgradeCannotBeCalledMultipleTimes"); + + /// + /// Request body too large. + /// + internal static string BadRequest_RequestBodyTooLarge + { + get => GetString("BadRequest_RequestBodyTooLarge"); + } + + /// + /// Request body too large. + /// + internal static string FormatBadRequest_RequestBodyTooLarge() + => GetString("BadRequest_RequestBodyTooLarge"); + + /// + /// The maximum request body size cannot be modified after the app has already started reading from the request body. + /// + internal static string MaxRequestBodySizeCannotBeModifiedAfterRead + { + get => GetString("MaxRequestBodySizeCannotBeModifiedAfterRead"); + } + + /// + /// The maximum request body size cannot be modified after the app has already started reading from the request body. + /// + internal static string FormatMaxRequestBodySizeCannotBeModifiedAfterRead() + => GetString("MaxRequestBodySizeCannotBeModifiedAfterRead"); + + /// + /// The maximum request body size cannot be modified after the request has been upgraded. + /// + internal static string MaxRequestBodySizeCannotBeModifiedForUpgradedRequests + { + get => GetString("MaxRequestBodySizeCannotBeModifiedForUpgradedRequests"); + } + + /// + /// The maximum request body size cannot be modified after the request has been upgraded. + /// + internal static string FormatMaxRequestBodySizeCannotBeModifiedForUpgradedRequests() + => GetString("MaxRequestBodySizeCannotBeModifiedForUpgradedRequests"); + + /// + /// Value must be a positive TimeSpan. + /// + internal static string PositiveTimeSpanRequired + { + get => GetString("PositiveTimeSpanRequired"); + } + + /// + /// Value must be a positive TimeSpan. + /// + internal static string FormatPositiveTimeSpanRequired() + => GetString("PositiveTimeSpanRequired"); + + /// + /// Value must be a non-negative TimeSpan. + /// + internal static string NonNegativeTimeSpanRequired + { + get => GetString("NonNegativeTimeSpanRequired"); + } + + /// + /// Value must be a non-negative TimeSpan. + /// + internal static string FormatNonNegativeTimeSpanRequired() + => GetString("NonNegativeTimeSpanRequired"); + + /// + /// The request body rate enforcement grace period must be greater than {heartbeatInterval} second. + /// + internal static string MinimumGracePeriodRequired + { + get => GetString("MinimumGracePeriodRequired"); + } + + /// + /// The request body rate enforcement grace period must be greater than {heartbeatInterval} second. + /// + internal static string FormatMinimumGracePeriodRequired(object heartbeatInterval) + => string.Format(CultureInfo.CurrentCulture, GetString("MinimumGracePeriodRequired", "heartbeatInterval"), heartbeatInterval); + + /// + /// Synchronous operations are disallowed. Call ReadAsync or set AllowSynchronousIO to true instead. + /// + internal static string SynchronousReadsDisallowed + { + get => GetString("SynchronousReadsDisallowed"); + } + + /// + /// Synchronous operations are disallowed. Call ReadAsync or set AllowSynchronousIO to true instead. + /// + internal static string FormatSynchronousReadsDisallowed() + => GetString("SynchronousReadsDisallowed"); + + /// + /// Synchronous operations are disallowed. Call WriteAsync or set AllowSynchronousIO to true instead. + /// + internal static string SynchronousWritesDisallowed + { + get => GetString("SynchronousWritesDisallowed"); + } + + /// + /// Synchronous operations are disallowed. Call WriteAsync or set AllowSynchronousIO to true instead. + /// + internal static string FormatSynchronousWritesDisallowed() + => GetString("SynchronousWritesDisallowed"); + + /// + /// Value must be a positive number. To disable a minimum data rate, use null where a MinDataRate instance is expected. + /// + internal static string PositiveNumberOrNullMinDataRateRequired + { + get => GetString("PositiveNumberOrNullMinDataRateRequired"); + } + + /// + /// Value must be a positive number. To disable a minimum data rate, use null where a MinDataRate instance is expected. + /// + internal static string FormatPositiveNumberOrNullMinDataRateRequired() + => GetString("PositiveNumberOrNullMinDataRateRequired"); + + /// + /// Concurrent timeouts are not supported. + /// + internal static string ConcurrentTimeoutsNotSupported + { + get => GetString("ConcurrentTimeoutsNotSupported"); + } + + /// + /// Concurrent timeouts are not supported. + /// + internal static string FormatConcurrentTimeoutsNotSupported() + => GetString("ConcurrentTimeoutsNotSupported"); + + /// + /// Timespan must be positive and finite. + /// + internal static string PositiveFiniteTimeSpanRequired + { + get => GetString("PositiveFiniteTimeSpanRequired"); + } + + /// + /// Timespan must be positive and finite. + /// + internal static string FormatPositiveFiniteTimeSpanRequired() + => GetString("PositiveFiniteTimeSpanRequired"); + + /// + /// An endpoint must be configured to serve at least one protocol. + /// + internal static string EndPointRequiresAtLeastOneProtocol + { + get => GetString("EndPointRequiresAtLeastOneProtocol"); + } + + /// + /// An endpoint must be configured to serve at least one protocol. + /// + internal static string FormatEndPointRequiresAtLeastOneProtocol() + => GetString("EndPointRequiresAtLeastOneProtocol"); + + /// + /// Using both HTTP/1.x and HTTP/2 on the same endpoint requires the use of TLS. + /// + internal static string EndPointRequiresTlsForHttp1AndHttp2 + { + get => GetString("EndPointRequiresTlsForHttp1AndHttp2"); + } + + /// + /// Using both HTTP/1.x and HTTP/2 on the same endpoint requires the use of TLS. + /// + internal static string FormatEndPointRequiresTlsForHttp1AndHttp2() + => GetString("EndPointRequiresTlsForHttp1AndHttp2"); + + /// + /// HTTP/2 over TLS was not negotiated on an HTTP/2-only endpoint. + /// + internal static string EndPointHttp2NotNegotiated + { + get => GetString("EndPointHttp2NotNegotiated"); + } + + /// + /// HTTP/2 over TLS was not negotiated on an HTTP/2-only endpoint. + /// + internal static string FormatEndPointHttp2NotNegotiated() + => GetString("EndPointHttp2NotNegotiated"); + + /// + /// A dynamic table size of {size} octets is greater than the configured maximum size of {maxSize} octets. + /// + internal static string HPackErrorDynamicTableSizeUpdateTooLarge + { + get => GetString("HPackErrorDynamicTableSizeUpdateTooLarge"); + } + + /// + /// A dynamic table size of {size} octets is greater than the configured maximum size of {maxSize} octets. + /// + internal static string FormatHPackErrorDynamicTableSizeUpdateTooLarge(object size, object maxSize) + => string.Format(CultureInfo.CurrentCulture, GetString("HPackErrorDynamicTableSizeUpdateTooLarge", "size", "maxSize"), size, maxSize); + + /// + /// Index {index} is outside the bounds of the header field table. + /// + internal static string HPackErrorIndexOutOfRange + { + get => GetString("HPackErrorIndexOutOfRange"); + } + + /// + /// Index {index} is outside the bounds of the header field table. + /// + internal static string FormatHPackErrorIndexOutOfRange(object index) + => string.Format(CultureInfo.CurrentCulture, GetString("HPackErrorIndexOutOfRange", "index"), index); + + /// + /// Input data could not be fully decoded. + /// + internal static string HPackHuffmanErrorIncomplete + { + get => GetString("HPackHuffmanErrorIncomplete"); + } + + /// + /// Input data could not be fully decoded. + /// + internal static string FormatHPackHuffmanErrorIncomplete() + => GetString("HPackHuffmanErrorIncomplete"); + + /// + /// Input data contains the EOS symbol. + /// + internal static string HPackHuffmanErrorEOS + { + get => GetString("HPackHuffmanErrorEOS"); + } + + /// + /// Input data contains the EOS symbol. + /// + internal static string FormatHPackHuffmanErrorEOS() + => GetString("HPackHuffmanErrorEOS"); + + /// + /// The destination buffer is not large enough to store the decoded data. + /// + internal static string HPackHuffmanErrorDestinationTooSmall + { + get => GetString("HPackHuffmanErrorDestinationTooSmall"); + } + + /// + /// The destination buffer is not large enough to store the decoded data. + /// + internal static string FormatHPackHuffmanErrorDestinationTooSmall() + => GetString("HPackHuffmanErrorDestinationTooSmall"); + + /// + /// Huffman decoding error. + /// + internal static string HPackHuffmanError + { + get => GetString("HPackHuffmanError"); + } + + /// + /// Huffman decoding error. + /// + internal static string FormatHPackHuffmanError() + => GetString("HPackHuffmanError"); + + /// + /// Decoded string length of {length} octets is greater than the configured maximum length of {maxStringLength} octets. + /// + internal static string HPackStringLengthTooLarge + { + get => GetString("HPackStringLengthTooLarge"); + } + + /// + /// Decoded string length of {length} octets is greater than the configured maximum length of {maxStringLength} octets. + /// + internal static string FormatHPackStringLengthTooLarge(object length, object maxStringLength) + => string.Format(CultureInfo.CurrentCulture, GetString("HPackStringLengthTooLarge", "length", "maxStringLength"), length, maxStringLength); + + /// + /// The header block was incomplete and could not be fully decoded. + /// + internal static string HPackErrorIncompleteHeaderBlock + { + get => GetString("HPackErrorIncompleteHeaderBlock"); + } + + /// + /// The header block was incomplete and could not be fully decoded. + /// + internal static string FormatHPackErrorIncompleteHeaderBlock() + => GetString("HPackErrorIncompleteHeaderBlock"); + + /// + /// The client sent a {frameType} frame with even stream ID {streamId}. + /// + internal static string Http2ErrorStreamIdEven + { + get => GetString("Http2ErrorStreamIdEven"); + } + + /// + /// The client sent a {frameType} frame with even stream ID {streamId}. + /// + internal static string FormatHttp2ErrorStreamIdEven(object frameType, object streamId) + => string.Format(CultureInfo.CurrentCulture, GetString("Http2ErrorStreamIdEven", "frameType", "streamId"), frameType, streamId); + + /// + /// The client sent a A PUSH_PROMISE frame. + /// + internal static string Http2ErrorPushPromiseReceived + { + get => GetString("Http2ErrorPushPromiseReceived"); + } + + /// + /// The client sent a A PUSH_PROMISE frame. + /// + internal static string FormatHttp2ErrorPushPromiseReceived() + => GetString("Http2ErrorPushPromiseReceived"); + + /// + /// The client sent a {frameType} frame to stream ID {streamId} before signaling of the header block for stream ID {headersStreamId}. + /// + internal static string Http2ErrorHeadersInterleaved + { + get => GetString("Http2ErrorHeadersInterleaved"); + } + + /// + /// The client sent a {frameType} frame to stream ID {streamId} before signaling of the header block for stream ID {headersStreamId}. + /// + internal static string FormatHttp2ErrorHeadersInterleaved(object frameType, object streamId, object headersStreamId) + => string.Format(CultureInfo.CurrentCulture, GetString("Http2ErrorHeadersInterleaved", "frameType", "streamId", "headersStreamId"), frameType, streamId, headersStreamId); + + /// + /// The client sent a {frameType} frame with stream ID 0. + /// + internal static string Http2ErrorStreamIdZero + { + get => GetString("Http2ErrorStreamIdZero"); + } + + /// + /// The client sent a {frameType} frame with stream ID 0. + /// + internal static string FormatHttp2ErrorStreamIdZero(object frameType) + => string.Format(CultureInfo.CurrentCulture, GetString("Http2ErrorStreamIdZero", "frameType"), frameType); + + /// + /// The client sent a {frameType} frame with stream ID different than 0. + /// + internal static string Http2ErrorStreamIdNotZero + { + get => GetString("Http2ErrorStreamIdNotZero"); + } + + /// + /// The client sent a {frameType} frame with stream ID different than 0. + /// + internal static string FormatHttp2ErrorStreamIdNotZero(object frameType) + => string.Format(CultureInfo.CurrentCulture, GetString("Http2ErrorStreamIdNotZero", "frameType"), frameType); + + /// + /// The client sent a {frameType} frame with padding longer than or with the same length as the sent data. + /// + internal static string Http2ErrorPaddingTooLong + { + get => GetString("Http2ErrorPaddingTooLong"); + } + + /// + /// The client sent a {frameType} frame with padding longer than or with the same length as the sent data. + /// + internal static string FormatHttp2ErrorPaddingTooLong(object frameType) + => string.Format(CultureInfo.CurrentCulture, GetString("Http2ErrorPaddingTooLong", "frameType"), frameType); + + /// + /// The client sent a {frameType} frame to closed stream ID {streamId}. + /// + internal static string Http2ErrorStreamClosed + { + get => GetString("Http2ErrorStreamClosed"); + } + + /// + /// The client sent a {frameType} frame to closed stream ID {streamId}. + /// + internal static string FormatHttp2ErrorStreamClosed(object frameType, object streamId) + => string.Format(CultureInfo.CurrentCulture, GetString("Http2ErrorStreamClosed", "frameType", "streamId"), frameType, streamId); + + /// + /// The client sent a {frameType} frame to stream ID {streamId} which is in the "half-closed (remote) state". + /// + internal static string Http2ErrorStreamHalfClosedRemote + { + get => GetString("Http2ErrorStreamHalfClosedRemote"); + } + + /// + /// The client sent a {frameType} frame to stream ID {streamId} which is in the "half-closed (remote) state". + /// + internal static string FormatHttp2ErrorStreamHalfClosedRemote(object frameType, object streamId) + => string.Format(CultureInfo.CurrentCulture, GetString("Http2ErrorStreamHalfClosedRemote", "frameType", "streamId"), frameType, streamId); + + /// + /// The client sent a {frameType} frame with dependency information that would cause stream ID {streamId} to depend on itself. + /// + internal static string Http2ErrorStreamSelfDependency + { + get => GetString("Http2ErrorStreamSelfDependency"); + } + + /// + /// The client sent a {frameType} frame with dependency information that would cause stream ID {streamId} to depend on itself. + /// + internal static string FormatHttp2ErrorStreamSelfDependency(object frameType, object streamId) + => string.Format(CultureInfo.CurrentCulture, GetString("Http2ErrorStreamSelfDependency", "frameType", "streamId"), frameType, streamId); + + /// + /// The client sent a {frameType} frame with length different than {expectedLength}. + /// + internal static string Http2ErrorUnexpectedFrameLength + { + get => GetString("Http2ErrorUnexpectedFrameLength"); + } + + /// + /// The client sent a {frameType} frame with length different than {expectedLength}. + /// + internal static string FormatHttp2ErrorUnexpectedFrameLength(object frameType, object expectedLength) + => string.Format(CultureInfo.CurrentCulture, GetString("Http2ErrorUnexpectedFrameLength", "frameType", "expectedLength"), frameType, expectedLength); + + /// + /// The client sent a SETTINGS frame with a length that is not a multiple of 6. + /// + internal static string Http2ErrorSettingsLengthNotMultipleOfSix + { + get => GetString("Http2ErrorSettingsLengthNotMultipleOfSix"); + } + + /// + /// The client sent a SETTINGS frame with a length that is not a multiple of 6. + /// + internal static string FormatHttp2ErrorSettingsLengthNotMultipleOfSix() + => GetString("Http2ErrorSettingsLengthNotMultipleOfSix"); + + /// + /// The client sent a SETTINGS frame with ACK set and length different than 0. + /// + internal static string Http2ErrorSettingsAckLengthNotZero + { + get => GetString("Http2ErrorSettingsAckLengthNotZero"); + } + + /// + /// The client sent a SETTINGS frame with ACK set and length different than 0. + /// + internal static string FormatHttp2ErrorSettingsAckLengthNotZero() + => GetString("Http2ErrorSettingsAckLengthNotZero"); + + /// + /// The client sent a SETTINGS frame with a value for parameter {parameter} that is out of range. + /// + internal static string Http2ErrorSettingsParameterOutOfRange + { + get => GetString("Http2ErrorSettingsParameterOutOfRange"); + } + + /// + /// The client sent a SETTINGS frame with a value for parameter {parameter} that is out of range. + /// + internal static string FormatHttp2ErrorSettingsParameterOutOfRange(object parameter) + => string.Format(CultureInfo.CurrentCulture, GetString("Http2ErrorSettingsParameterOutOfRange", "parameter"), parameter); + + /// + /// The client sent a WINDOW_UPDATE frame with a window size increment of 0. + /// + internal static string Http2ErrorWindowUpdateIncrementZero + { + get => GetString("Http2ErrorWindowUpdateIncrementZero"); + } + + /// + /// The client sent a WINDOW_UPDATE frame with a window size increment of 0. + /// + internal static string FormatHttp2ErrorWindowUpdateIncrementZero() + => GetString("Http2ErrorWindowUpdateIncrementZero"); + + /// + /// The client sent a CONTINUATION frame not preceded by a HEADERS frame. + /// + internal static string Http2ErrorContinuationWithNoHeaders + { + get => GetString("Http2ErrorContinuationWithNoHeaders"); + } + + /// + /// The client sent a CONTINUATION frame not preceded by a HEADERS frame. + /// + internal static string FormatHttp2ErrorContinuationWithNoHeaders() + => GetString("Http2ErrorContinuationWithNoHeaders"); + + /// + /// The client sent a {frameType} frame to idle stream ID {streamId}. + /// + internal static string Http2ErrorStreamIdle + { + get => GetString("Http2ErrorStreamIdle"); + } + + /// + /// The client sent a {frameType} frame to idle stream ID {streamId}. + /// + internal static string FormatHttp2ErrorStreamIdle(object frameType, object streamId) + => string.Format(CultureInfo.CurrentCulture, GetString("Http2ErrorStreamIdle", "frameType", "streamId"), frameType, streamId); + + /// + /// The client sent trailers containing one or more pseudo-header fields. + /// + internal static string Http2ErrorTrailersContainPseudoHeaderField + { + get => GetString("Http2ErrorTrailersContainPseudoHeaderField"); + } + + /// + /// The client sent trailers containing one or more pseudo-header fields. + /// + internal static string FormatHttp2ErrorTrailersContainPseudoHeaderField() + => GetString("Http2ErrorTrailersContainPseudoHeaderField"); + + /// + /// The client sent a header with uppercase characters in its name. + /// + internal static string Http2ErrorHeaderNameUppercase + { + get => GetString("Http2ErrorHeaderNameUppercase"); + } + + /// + /// The client sent a header with uppercase characters in its name. + /// + internal static string FormatHttp2ErrorHeaderNameUppercase() + => GetString("Http2ErrorHeaderNameUppercase"); + + /// + /// The client sent a trailer with uppercase characters in its name. + /// + internal static string Http2ErrorTrailerNameUppercase + { + get => GetString("Http2ErrorTrailerNameUppercase"); + } + + /// + /// The client sent a trailer with uppercase characters in its name. + /// + internal static string FormatHttp2ErrorTrailerNameUppercase() + => GetString("Http2ErrorTrailerNameUppercase"); + + /// + /// The client sent a HEADERS frame containing trailers without setting the END_STREAM flag. + /// + internal static string Http2ErrorHeadersWithTrailersNoEndStream + { + get => GetString("Http2ErrorHeadersWithTrailersNoEndStream"); + } + + /// + /// The client sent a HEADERS frame containing trailers without setting the END_STREAM flag. + /// + internal static string FormatHttp2ErrorHeadersWithTrailersNoEndStream() + => GetString("Http2ErrorHeadersWithTrailersNoEndStream"); + + /// + /// Request headers missing one or more mandatory pseudo-header fields. + /// + internal static string Http2ErrorMissingMandatoryPseudoHeaderFields + { + get => GetString("Http2ErrorMissingMandatoryPseudoHeaderFields"); + } + + /// + /// Request headers missing one or more mandatory pseudo-header fields. + /// + internal static string FormatHttp2ErrorMissingMandatoryPseudoHeaderFields() + => GetString("Http2ErrorMissingMandatoryPseudoHeaderFields"); + + /// + /// Pseudo-header field found in request headers after regular header fields. + /// + internal static string Http2ErrorPseudoHeaderFieldAfterRegularHeaders + { + get => GetString("Http2ErrorPseudoHeaderFieldAfterRegularHeaders"); + } + + /// + /// Pseudo-header field found in request headers after regular header fields. + /// + internal static string FormatHttp2ErrorPseudoHeaderFieldAfterRegularHeaders() + => GetString("Http2ErrorPseudoHeaderFieldAfterRegularHeaders"); + + /// + /// Request headers contain unknown pseudo-header field. + /// + internal static string Http2ErrorUnknownPseudoHeaderField + { + get => GetString("Http2ErrorUnknownPseudoHeaderField"); + } + + /// + /// Request headers contain unknown pseudo-header field. + /// + internal static string FormatHttp2ErrorUnknownPseudoHeaderField() + => GetString("Http2ErrorUnknownPseudoHeaderField"); + + /// + /// Request headers contain response-specific pseudo-header field. + /// + internal static string Http2ErrorResponsePseudoHeaderField + { + get => GetString("Http2ErrorResponsePseudoHeaderField"); + } + + /// + /// Request headers contain response-specific pseudo-header field. + /// + internal static string FormatHttp2ErrorResponsePseudoHeaderField() + => GetString("Http2ErrorResponsePseudoHeaderField"); + + /// + /// Request headers contain duplicate pseudo-header field. + /// + internal static string Http2ErrorDuplicatePseudoHeaderField + { + get => GetString("Http2ErrorDuplicatePseudoHeaderField"); + } + + /// + /// Request headers contain duplicate pseudo-header field. + /// + internal static string FormatHttp2ErrorDuplicatePseudoHeaderField() + => GetString("Http2ErrorDuplicatePseudoHeaderField"); + + /// + /// Request headers contain connection-specific header field. + /// + internal static string Http2ErrorConnectionSpecificHeaderField + { + get => GetString("Http2ErrorConnectionSpecificHeaderField"); + } + + /// + /// Request headers contain connection-specific header field. + /// + internal static string FormatHttp2ErrorConnectionSpecificHeaderField() + => GetString("Http2ErrorConnectionSpecificHeaderField"); + + /// + /// Unable to configure default https bindings because no IDefaultHttpsProvider service was provided. + /// + internal static string UnableToConfigureHttpsBindings + { + get => GetString("UnableToConfigureHttpsBindings"); + } + + /// + /// Unable to configure default https bindings because no IDefaultHttpsProvider service was provided. + /// + internal static string FormatUnableToConfigureHttpsBindings() + => GetString("UnableToConfigureHttpsBindings"); + + /// + /// Failed to authenticate HTTPS connection. + /// + internal static string AuthenticationFailed + { + get => GetString("AuthenticationFailed"); + } + + /// + /// Failed to authenticate HTTPS connection. + /// + internal static string FormatAuthenticationFailed() + => GetString("AuthenticationFailed"); + + /// + /// Authentication of the HTTPS connection timed out. + /// + internal static string AuthenticationTimedOut + { + get => GetString("AuthenticationTimedOut"); + } + + /// + /// Authentication of the HTTPS connection timed out. + /// + internal static string FormatAuthenticationTimedOut() + => GetString("AuthenticationTimedOut"); + + /// + /// Certificate {thumbprint} cannot be used as an SSL server certificate. It has an Extended Key Usage extension but the usages do not include Server Authentication (OID 1.3.6.1.5.5.7.3.1). + /// + internal static string InvalidServerCertificateEku + { + get => GetString("InvalidServerCertificateEku"); + } + + /// + /// Certificate {thumbprint} cannot be used as an SSL server certificate. It has an Extended Key Usage extension but the usages do not include Server Authentication (OID 1.3.6.1.5.5.7.3.1). + /// + internal static string FormatInvalidServerCertificateEku(object thumbprint) + => string.Format(CultureInfo.CurrentCulture, GetString("InvalidServerCertificateEku", "thumbprint"), thumbprint); + + /// + /// Value must be a positive TimeSpan. + /// + internal static string PositiveTimeSpanRequired1 + { + get => GetString("PositiveTimeSpanRequired1"); + } + + /// + /// Value must be a positive TimeSpan. + /// + internal static string FormatPositiveTimeSpanRequired1() + => GetString("PositiveTimeSpanRequired1"); + + /// + /// The server certificate parameter is required. + /// + internal static string ServerCertificateRequired + { + get => GetString("ServerCertificateRequired"); + } + + /// + /// The server certificate parameter is required. + /// + internal static string FormatServerCertificateRequired() + => GetString("ServerCertificateRequired"); + + /// + /// No listening endpoints were configured. Binding to {address0} and {address1} by default. + /// + internal static string BindingToDefaultAddresses + { + get => GetString("BindingToDefaultAddresses"); + } + + /// + /// No listening endpoints were configured. Binding to {address0} and {address1} by default. + /// + internal static string FormatBindingToDefaultAddresses(object address0, object address1) + => string.Format(CultureInfo.CurrentCulture, GetString("BindingToDefaultAddresses", "address0", "address1"), address0, address1); + + /// + /// The requested certificate {subject} could not be found in {storeLocation}/{storeName} with AllowInvalid setting: {allowInvalid}. + /// + internal static string CertNotFoundInStore + { + get => GetString("CertNotFoundInStore"); + } + + /// + /// The requested certificate {subject} could not be found in {storeLocation}/{storeName} with AllowInvalid setting: {allowInvalid}. + /// + internal static string FormatCertNotFoundInStore(object subject, object storeLocation, object storeName, object allowInvalid) + => string.Format(CultureInfo.CurrentCulture, GetString("CertNotFoundInStore", "subject", "storeLocation", "storeName", "allowInvalid"), subject, storeLocation, storeName, allowInvalid); + + /// + /// The endpoint {endpointName} is missing the required 'Url' parameter. + /// + internal static string EndpointMissingUrl + { + get => GetString("EndpointMissingUrl"); + } + + /// + /// The endpoint {endpointName} is missing the required 'Url' parameter. + /// + internal static string FormatEndpointMissingUrl(object endpointName) + => string.Format(CultureInfo.CurrentCulture, GetString("EndpointMissingUrl", "endpointName"), endpointName); + + /// + /// Unable to configure HTTPS endpoint. No server certificate was specified, and the default developer certificate could not be found. + /// To generate a developer certificate run 'dotnet dev-certs https'. To trust the certificate (Windows and macOS only) run 'dotnet dev-certs https --trust'. + /// For more information on configuring HTTPS see https://go.microsoft.com/fwlink/?linkid=848054. + /// + internal static string NoCertSpecifiedNoDevelopmentCertificateFound + { + get => GetString("NoCertSpecifiedNoDevelopmentCertificateFound"); + } + + /// + /// Unable to configure HTTPS endpoint. No server certificate was specified, and the default developer certificate could not be found. + /// To generate a developer certificate run 'dotnet dev-certs https'. To trust the certificate (Windows and macOS only) run 'dotnet dev-certs https --trust'. + /// For more information on configuring HTTPS see https://go.microsoft.com/fwlink/?linkid=848054. + /// + internal static string FormatNoCertSpecifiedNoDevelopmentCertificateFound() + => GetString("NoCertSpecifiedNoDevelopmentCertificateFound"); + + /// + /// The endpoint {endpointName} specified multiple certificate sources. + /// + internal static string MultipleCertificateSources + { + get => GetString("MultipleCertificateSources"); + } + + /// + /// The endpoint {endpointName} specified multiple certificate sources. + /// + internal static string FormatMultipleCertificateSources(object endpointName) + => string.Format(CultureInfo.CurrentCulture, GetString("MultipleCertificateSources", "endpointName"), endpointName); + + /// + /// HTTP/2 support is experimental, see https://go.microsoft.com/fwlink/?linkid=866785 to enable it. + /// + internal static string Http2NotSupported + { + get => GetString("Http2NotSupported"); + } + + /// + /// HTTP/2 support is experimental, see https://go.microsoft.com/fwlink/?linkid=866785 to enable it. + /// + internal static string FormatHttp2NotSupported() + => GetString("Http2NotSupported"); + + /// + /// Cannot write to the response body, the response has completed. + /// + internal static string WritingToResponseBodyAfterResponseCompleted + { + get => GetString("WritingToResponseBodyAfterResponseCompleted"); + } + + /// + /// Cannot write to the response body, the response has completed. + /// + internal static string FormatWritingToResponseBodyAfterResponseCompleted() + => GetString("WritingToResponseBodyAfterResponseCompleted"); + + /// + /// Reading the request body timed out due to data arriving too slowly. See MinRequestBodyDataRate. + /// + internal static string BadRequest_RequestBodyTimeout + { + get => GetString("BadRequest_RequestBodyTimeout"); + } + + /// + /// Reading the request body timed out due to data arriving too slowly. See MinRequestBodyDataRate. + /// + internal static string FormatBadRequest_RequestBodyTimeout() + => GetString("BadRequest_RequestBodyTimeout"); + + /// + /// The connection was aborted by the application. + /// + internal static string ConnectionAbortedByApplication + { + get => GetString("ConnectionAbortedByApplication"); + } + + /// + /// The connection was aborted by the application. + /// + internal static string FormatConnectionAbortedByApplication() + => GetString("ConnectionAbortedByApplication"); + + /// + /// The connection was aborted because the server is shutting down and request processing didn't complete within the time specified by HostOptions.ShutdownTimeout. + /// + internal static string ConnectionAbortedDuringServerShutdown + { + get => GetString("ConnectionAbortedDuringServerShutdown"); + } + + /// + /// The connection was aborted because the server is shutting down and request processing didn't complete within the time specified by HostOptions.ShutdownTimeout. + /// + internal static string FormatConnectionAbortedDuringServerShutdown() + => GetString("ConnectionAbortedDuringServerShutdown"); + + /// + /// The connection was timed out by the server because the response was not read by the client at the specified minimum data rate. + /// + internal static string ConnectionTimedBecauseResponseMininumDataRateNotSatisfied + { + get => GetString("ConnectionTimedBecauseResponseMininumDataRateNotSatisfied"); + } + + /// + /// The connection was timed out by the server because the response was not read by the client at the specified minimum data rate. + /// + internal static string FormatConnectionTimedBecauseResponseMininumDataRateNotSatisfied() + => GetString("ConnectionTimedBecauseResponseMininumDataRateNotSatisfied"); + + /// + /// The connection was timed out by the server. + /// + internal static string ConnectionTimedOutByServer + { + get => GetString("ConnectionTimedOutByServer"); + } + + /// + /// The connection was timed out by the server. + /// + internal static string FormatConnectionTimedOutByServer() + => GetString("ConnectionTimedOutByServer"); + + private static string GetString(string name, params string[] formatterNames) + { + var value = _resourceManager.GetString(name); + + System.Diagnostics.Debug.Assert(value != null); + + if (formatterNames != null) + { + for (var i = 0; i < formatterNames.Length; i++) + { + value = value.Replace("{" + formatterNames[i] + "}", "{" + i + "}"); + } + } + + return value; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/ServerAddress.cs b/src/Servers/Kestrel/Core/src/ServerAddress.cs new file mode 100644 index 0000000000..0b7f019c85 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/ServerAddress.cs @@ -0,0 +1,152 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Globalization; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core +{ + public class ServerAddress + { + public string Host { get; private set; } + public string PathBase { get; private set; } + public int Port { get; internal set; } + public string Scheme { get; private set; } + + public bool IsUnixPipe + { + get + { + return Host.StartsWith(Constants.UnixPipeHostPrefix, StringComparison.Ordinal); + } + } + + public string UnixPipePath + { + get + { + Debug.Assert(IsUnixPipe); + + return Host.Substring(Constants.UnixPipeHostPrefix.Length - 1); + } + } + + public override string ToString() + { + if (IsUnixPipe) + { + return Scheme.ToLowerInvariant() + "://" + Host.ToLowerInvariant(); + } + else + { + return Scheme.ToLowerInvariant() + "://" + Host.ToLowerInvariant() + ":" + Port.ToString(CultureInfo.InvariantCulture); + } + } + + public override int GetHashCode() + { + return ToString().GetHashCode(); + } + + public override bool Equals(object obj) + { + var other = obj as ServerAddress; + if (other == null) + { + return false; + } + return string.Equals(Scheme, other.Scheme, StringComparison.OrdinalIgnoreCase) + && string.Equals(Host, other.Host, StringComparison.OrdinalIgnoreCase) + && Port == other.Port; + } + + public static ServerAddress FromUrl(string url) + { + url = url ?? string.Empty; + + int schemeDelimiterStart = url.IndexOf("://", StringComparison.Ordinal); + if (schemeDelimiterStart < 0) + { + throw new FormatException(CoreStrings.FormatInvalidUrl(url)); + } + int schemeDelimiterEnd = schemeDelimiterStart + "://".Length; + + var isUnixPipe = url.IndexOf(Constants.UnixPipeHostPrefix, schemeDelimiterEnd, StringComparison.Ordinal) == schemeDelimiterEnd; + + int pathDelimiterStart; + int pathDelimiterEnd; + if (!isUnixPipe) + { + pathDelimiterStart = url.IndexOf("/", schemeDelimiterEnd, StringComparison.Ordinal); + pathDelimiterEnd = pathDelimiterStart; + } + else + { + pathDelimiterStart = url.IndexOf(":", schemeDelimiterEnd + Constants.UnixPipeHostPrefix.Length, StringComparison.Ordinal); + pathDelimiterEnd = pathDelimiterStart + ":".Length; + } + + if (pathDelimiterStart < 0) + { + pathDelimiterStart = pathDelimiterEnd = url.Length; + } + + var serverAddress = new ServerAddress(); + serverAddress.Scheme = url.Substring(0, schemeDelimiterStart); + + var hasSpecifiedPort = false; + if (!isUnixPipe) + { + int portDelimiterStart = url.LastIndexOf(":", pathDelimiterStart - 1, pathDelimiterStart - schemeDelimiterEnd, StringComparison.Ordinal); + if (portDelimiterStart >= 0) + { + int portDelimiterEnd = portDelimiterStart + ":".Length; + + string portString = url.Substring(portDelimiterEnd, pathDelimiterStart - portDelimiterEnd); + int portNumber; + if (int.TryParse(portString, NumberStyles.Integer, CultureInfo.InvariantCulture, out portNumber)) + { + hasSpecifiedPort = true; + serverAddress.Host = url.Substring(schemeDelimiterEnd, portDelimiterStart - schemeDelimiterEnd); + serverAddress.Port = portNumber; + } + } + + if (!hasSpecifiedPort) + { + if (string.Equals(serverAddress.Scheme, "http", StringComparison.OrdinalIgnoreCase)) + { + serverAddress.Port = 80; + } + else if (string.Equals(serverAddress.Scheme, "https", StringComparison.OrdinalIgnoreCase)) + { + serverAddress.Port = 443; + } + } + } + + if (!hasSpecifiedPort) + { + serverAddress.Host = url.Substring(schemeDelimiterEnd, pathDelimiterStart - schemeDelimiterEnd); + } + + if (string.IsNullOrEmpty(serverAddress.Host)) + { + throw new FormatException(CoreStrings.FormatInvalidUrl(url)); + } + + if (url[url.Length - 1] == '/') + { + serverAddress.PathBase = url.Substring(pathDelimiterEnd, url.Length - pathDelimiterEnd - 1); + } + else + { + serverAddress.PathBase = url.Substring(pathDelimiterEnd); + } + + return serverAddress; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Systemd/KestrelServerOptionsSystemdExtensions.cs b/src/Servers/Kestrel/Core/src/Systemd/KestrelServerOptionsSystemdExtensions.cs new file mode 100644 index 0000000000..6def39159d --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Systemd/KestrelServerOptionsSystemdExtensions.cs @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Globalization; +using Microsoft.AspNetCore.Server.Kestrel.Core; + +namespace Microsoft.AspNetCore.Hosting +{ + public static class KestrelServerOptionsSystemdExtensions + { + // SD_LISTEN_FDS_START https://www.freedesktop.org/software/systemd/man/sd_listen_fds.html + private const ulong SdListenFdsStart = 3; + private const string ListenPidEnvVar = "LISTEN_PID"; + + /// + /// Open file descriptor (SD_LISTEN_FDS_START) initialized by systemd socket-based activation logic if available. + /// + /// + /// The . + /// + public static KestrelServerOptions UseSystemd(this KestrelServerOptions options) + { + return options.UseSystemd(_ => { }); + } + + /// + /// Open file descriptor (SD_LISTEN_FDS_START) initialized by systemd socket-based activation logic if available. + /// Specify callback to configure endpoint-specific settings. + /// + /// + /// The . + /// + public static KestrelServerOptions UseSystemd(this KestrelServerOptions options, Action configure) + { + if (string.Equals(Process.GetCurrentProcess().Id.ToString(CultureInfo.InvariantCulture), Environment.GetEnvironmentVariable(ListenPidEnvVar), StringComparison.Ordinal)) + { + options.ListenHandle(SdListenFdsStart, configure); + } + + return options; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/baseline.netcore.json b/src/Servers/Kestrel/Core/src/baseline.netcore.json new file mode 100644 index 0000000000..7a583ce888 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/baseline.netcore.json @@ -0,0 +1,1074 @@ +{ + "AssemblyIdentity": "Microsoft.AspNetCore.Server.Kestrel.Core, Version=2.0.2.0, Culture=neutral, PublicKeyToken=adb9793829ddae60", + "Types": [ + { + "Name": "Microsoft.AspNetCore.Hosting.ListenOptionsConnectionLoggingExtensions", + "Visibility": "Public", + "Kind": "Class", + "Abstract": true, + "Static": true, + "Sealed": true, + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "UseConnectionLogging", + "Parameters": [ + { + "Name": "listenOptions", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions" + } + ], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions", + "Static": true, + "Extension": true, + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "UseConnectionLogging", + "Parameters": [ + { + "Name": "listenOptions", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions" + }, + { + "Name": "loggerName", + "Type": "System.String" + } + ], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions", + "Static": true, + "Extension": true, + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Hosting.KestrelServerOptionsSystemdExtensions", + "Visibility": "Public", + "Kind": "Class", + "Abstract": true, + "Static": true, + "Sealed": true, + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "UseSystemd", + "Parameters": [ + { + "Name": "options", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Core.KestrelServerOptions" + } + ], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.KestrelServerOptions", + "Static": true, + "Extension": true, + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "UseSystemd", + "Parameters": [ + { + "Name": "options", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Core.KestrelServerOptions" + }, + { + "Name": "configure", + "Type": "System.Action" + } + ], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.KestrelServerOptions", + "Static": true, + "Extension": true, + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Core.BadHttpRequestException", + "Visibility": "Public", + "Kind": "Class", + "Sealed": true, + "BaseType": "System.IO.IOException", + "ImplementedInterfaces": [], + "Members": [], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Core.KestrelServer", + "Visibility": "Public", + "Kind": "Class", + "ImplementedInterfaces": [ + "Microsoft.AspNetCore.Hosting.Server.IServer" + ], + "Members": [ + { + "Kind": "Method", + "Name": "get_Features", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Http.Features.IFeatureCollection", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Hosting.Server.IServer", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_Options", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.KestrelServerOptions", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "StartAsync", + "Parameters": [ + { + "Name": "application", + "Type": "Microsoft.AspNetCore.Hosting.Server.IHttpApplication" + }, + { + "Name": "cancellationToken", + "Type": "System.Threading.CancellationToken" + } + ], + "ReturnType": "System.Threading.Tasks.Task", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Hosting.Server.IServer", + "Visibility": "Public", + "GenericParameter": [ + { + "ParameterName": "TContext", + "ParameterPosition": 0, + "BaseTypeOrInterfaces": [] + } + ] + }, + { + "Kind": "Method", + "Name": "StopAsync", + "Parameters": [ + { + "Name": "cancellationToken", + "Type": "System.Threading.CancellationToken" + } + ], + "ReturnType": "System.Threading.Tasks.Task", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Hosting.Server.IServer", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "Dispose", + "Parameters": [], + "ReturnType": "System.Void", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "System.IDisposable", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Constructor", + "Name": ".ctor", + "Parameters": [ + { + "Name": "options", + "Type": "Microsoft.Extensions.Options.IOptions" + }, + { + "Name": "transportFactory", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ITransportFactory" + }, + { + "Name": "loggerFactory", + "Type": "Microsoft.Extensions.Logging.ILoggerFactory" + } + ], + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Core.KestrelServerLimits", + "Visibility": "Public", + "Kind": "Class", + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "get_MaxResponseBufferSize", + "Parameters": [], + "ReturnType": "System.Nullable", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_MaxResponseBufferSize", + "Parameters": [ + { + "Name": "value", + "Type": "System.Nullable" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_MaxRequestBufferSize", + "Parameters": [], + "ReturnType": "System.Nullable", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_MaxRequestBufferSize", + "Parameters": [ + { + "Name": "value", + "Type": "System.Nullable" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_MaxRequestLineSize", + "Parameters": [], + "ReturnType": "System.Int32", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_MaxRequestLineSize", + "Parameters": [ + { + "Name": "value", + "Type": "System.Int32" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_MaxRequestHeadersTotalSize", + "Parameters": [], + "ReturnType": "System.Int32", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_MaxRequestHeadersTotalSize", + "Parameters": [ + { + "Name": "value", + "Type": "System.Int32" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_MaxRequestHeaderCount", + "Parameters": [], + "ReturnType": "System.Int32", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_MaxRequestHeaderCount", + "Parameters": [ + { + "Name": "value", + "Type": "System.Int32" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_MaxRequestBodySize", + "Parameters": [], + "ReturnType": "System.Nullable", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_MaxRequestBodySize", + "Parameters": [ + { + "Name": "value", + "Type": "System.Nullable" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_KeepAliveTimeout", + "Parameters": [], + "ReturnType": "System.TimeSpan", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_KeepAliveTimeout", + "Parameters": [ + { + "Name": "value", + "Type": "System.TimeSpan" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_RequestHeadersTimeout", + "Parameters": [], + "ReturnType": "System.TimeSpan", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_RequestHeadersTimeout", + "Parameters": [ + { + "Name": "value", + "Type": "System.TimeSpan" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_MaxConcurrentConnections", + "Parameters": [], + "ReturnType": "System.Nullable", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_MaxConcurrentConnections", + "Parameters": [ + { + "Name": "value", + "Type": "System.Nullable" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_MaxConcurrentUpgradedConnections", + "Parameters": [], + "ReturnType": "System.Nullable", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_MaxConcurrentUpgradedConnections", + "Parameters": [ + { + "Name": "value", + "Type": "System.Nullable" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_MinRequestBodyDataRate", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.MinDataRate", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_MinRequestBodyDataRate", + "Parameters": [ + { + "Name": "value", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Core.MinDataRate" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_MinResponseDataRate", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.MinDataRate", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_MinResponseDataRate", + "Parameters": [ + { + "Name": "value", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Core.MinDataRate" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Constructor", + "Name": ".ctor", + "Parameters": [], + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Core.KestrelServerOptions", + "Visibility": "Public", + "Kind": "Class", + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "get_AddServerHeader", + "Parameters": [], + "ReturnType": "System.Boolean", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_AddServerHeader", + "Parameters": [ + { + "Name": "value", + "Type": "System.Boolean" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_ApplicationSchedulingMode", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.SchedulingMode", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_ApplicationSchedulingMode", + "Parameters": [ + { + "Name": "value", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.SchedulingMode" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_AllowSynchronousIO", + "Parameters": [], + "ReturnType": "System.Boolean", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_AllowSynchronousIO", + "Parameters": [ + { + "Name": "value", + "Type": "System.Boolean" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_ApplicationServices", + "Parameters": [], + "ReturnType": "System.IServiceProvider", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_ApplicationServices", + "Parameters": [ + { + "Name": "value", + "Type": "System.IServiceProvider" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_Limits", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.KestrelServerLimits", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "Listen", + "Parameters": [ + { + "Name": "address", + "Type": "System.Net.IPAddress" + }, + { + "Name": "port", + "Type": "System.Int32" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "Listen", + "Parameters": [ + { + "Name": "address", + "Type": "System.Net.IPAddress" + }, + { + "Name": "port", + "Type": "System.Int32" + }, + { + "Name": "configure", + "Type": "System.Action" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "Listen", + "Parameters": [ + { + "Name": "endPoint", + "Type": "System.Net.IPEndPoint" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "Listen", + "Parameters": [ + { + "Name": "endPoint", + "Type": "System.Net.IPEndPoint" + }, + { + "Name": "configure", + "Type": "System.Action" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "ListenUnixSocket", + "Parameters": [ + { + "Name": "socketPath", + "Type": "System.String" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "ListenUnixSocket", + "Parameters": [ + { + "Name": "socketPath", + "Type": "System.String" + }, + { + "Name": "configure", + "Type": "System.Action" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "ListenHandle", + "Parameters": [ + { + "Name": "handle", + "Type": "System.UInt64" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "ListenHandle", + "Parameters": [ + { + "Name": "handle", + "Type": "System.UInt64" + }, + { + "Name": "configure", + "Type": "System.Action" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Constructor", + "Name": ".ctor", + "Parameters": [], + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions", + "Visibility": "Public", + "Kind": "Class", + "ImplementedInterfaces": [ + "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IEndPointInformation" + ], + "Members": [ + { + "Kind": "Method", + "Name": "get_Type", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ListenType", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IEndPointInformation", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_HandleType", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.FileHandleType", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IEndPointInformation", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_HandleType", + "Parameters": [ + { + "Name": "value", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.FileHandleType" + } + ], + "ReturnType": "System.Void", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IEndPointInformation", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_IPEndPoint", + "Parameters": [], + "ReturnType": "System.Net.IPEndPoint", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IEndPointInformation", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_IPEndPoint", + "Parameters": [ + { + "Name": "value", + "Type": "System.Net.IPEndPoint" + } + ], + "ReturnType": "System.Void", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IEndPointInformation", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_SocketPath", + "Parameters": [], + "ReturnType": "System.String", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IEndPointInformation", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_FileHandle", + "Parameters": [], + "ReturnType": "System.UInt64", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IEndPointInformation", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_KestrelServerOptions", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.KestrelServerOptions", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_NoDelay", + "Parameters": [], + "ReturnType": "System.Boolean", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IEndPointInformation", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_NoDelay", + "Parameters": [ + { + "Name": "value", + "Type": "System.Boolean" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_ConnectionAdapters", + "Parameters": [], + "ReturnType": "System.Collections.Generic.List", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "ToString", + "Parameters": [], + "ReturnType": "System.String", + "Virtual": true, + "Override": true, + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Core.MinDataRate", + "Visibility": "Public", + "Kind": "Class", + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "get_BytesPerSecond", + "Parameters": [], + "ReturnType": "System.Double", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_GracePeriod", + "Parameters": [], + "ReturnType": "System.TimeSpan", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Constructor", + "Name": ".ctor", + "Parameters": [ + { + "Name": "bytesPerSecond", + "Type": "System.Double" + }, + { + "Name": "gracePeriod", + "Type": "System.TimeSpan" + } + ], + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Core.ServerAddress", + "Visibility": "Public", + "Kind": "Class", + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "get_Host", + "Parameters": [], + "ReturnType": "System.String", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_PathBase", + "Parameters": [], + "ReturnType": "System.String", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_Port", + "Parameters": [], + "ReturnType": "System.Int32", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_Scheme", + "Parameters": [], + "ReturnType": "System.String", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_IsUnixPipe", + "Parameters": [], + "ReturnType": "System.Boolean", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_UnixPipePath", + "Parameters": [], + "ReturnType": "System.String", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "ToString", + "Parameters": [], + "ReturnType": "System.String", + "Virtual": true, + "Override": true, + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "GetHashCode", + "Parameters": [], + "ReturnType": "System.Int32", + "Virtual": true, + "Override": true, + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "Equals", + "Parameters": [ + { + "Name": "obj", + "Type": "System.Object" + } + ], + "ReturnType": "System.Boolean", + "Virtual": true, + "Override": true, + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "FromUrl", + "Parameters": [ + { + "Name": "url", + "Type": "System.String" + } + ], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.ServerAddress", + "Static": true, + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Constructor", + "Name": ".ctor", + "Parameters": [], + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Core.Features.IHttpMinRequestBodyDataRateFeature", + "Visibility": "Public", + "Kind": "Interface", + "Abstract": true, + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "get_MinDataRate", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.MinDataRate", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_MinDataRate", + "Parameters": [ + { + "Name": "value", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Core.MinDataRate" + } + ], + "ReturnType": "System.Void", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Core.Features.IHttpMinResponseDataRateFeature", + "Visibility": "Public", + "Kind": "Interface", + "Abstract": true, + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "get_MinDataRate", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.MinDataRate", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_MinDataRate", + "Parameters": [ + { + "Name": "value", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Core.MinDataRate" + } + ], + "ReturnType": "System.Void", + "GenericParameter": [] + } + ], + "GenericParameters": [] + } + ] +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/test/AddressBinderTests.cs b/src/Servers/Kestrel/Core/test/AddressBinderTests.cs new file mode 100644 index 0000000000..426e490502 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/AddressBinderTests.cs @@ -0,0 +1,154 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class AddressBinderTests + { + [Theory] + [InlineData("http://10.10.10.10:5000/", "10.10.10.10", 5000)] + [InlineData("http://[::1]:5000", "::1", 5000)] + [InlineData("http://[::1]", "::1", 80)] + [InlineData("http://127.0.0.1", "127.0.0.1", 80)] + [InlineData("https://127.0.0.1", "127.0.0.1", 443)] + public void CorrectIPEndpointsAreCreated(string address, string expectedAddress, int expectedPort) + { + Assert.True(AddressBinder.TryCreateIPEndPoint( + ServerAddress.FromUrl(address), out var endpoint)); + Assert.NotNull(endpoint); + Assert.Equal(IPAddress.Parse(expectedAddress), endpoint.Address); + Assert.Equal(expectedPort, endpoint.Port); + } + + [Theory] + [InlineData("http://*")] + [InlineData("http://*:5000")] + [InlineData("http://+:80")] + [InlineData("http://+")] + [InlineData("http://randomhost:6000")] + [InlineData("http://randomhost")] + [InlineData("https://randomhost")] + public void DoesNotCreateIPEndPointOnInvalidIPAddress(string address) + { + Assert.False(AddressBinder.TryCreateIPEndPoint( + ServerAddress.FromUrl(address), out var endpoint)); + } + + [Theory] + [InlineData("*")] + [InlineData("randomhost")] + [InlineData("+")] + [InlineData("contoso.com")] + public void ParseAddressDefaultsToAnyIPOnInvalidIPAddress(string host) + { + var options = new KestrelServerOptions(); + var listenOptions = AddressBinder.ParseAddress($"http://{host}", out var https); + Assert.IsType(listenOptions); + Assert.Equal(ListenType.IPEndPoint, listenOptions.Type); + Assert.Equal(IPAddress.IPv6Any, listenOptions.IPEndPoint.Address); + Assert.Equal(80, listenOptions.IPEndPoint.Port); + Assert.False(https); + } + + [Fact] + public void ParseAddressLocalhost() + { + var options = new KestrelServerOptions(); + var listenOptions = AddressBinder.ParseAddress("http://localhost", out var https); + Assert.IsType(listenOptions); + Assert.Equal(ListenType.IPEndPoint, listenOptions.Type); + Assert.Equal(IPAddress.Loopback, listenOptions.IPEndPoint.Address); + Assert.Equal(80, listenOptions.IPEndPoint.Port); + Assert.False(https); + } + + [Fact] + public void ParseAddressUnixPipe() + { + var options = new KestrelServerOptions(); + var listenOptions = AddressBinder.ParseAddress("http://unix:/tmp/kestrel-test.sock", out var https); + Assert.Equal(ListenType.SocketPath, listenOptions.Type); + Assert.Equal("/tmp/kestrel-test.sock", listenOptions.SocketPath); + Assert.False(https); + } + + [Theory] + [InlineData("http://10.10.10.10:5000/", "10.10.10.10", 5000, false)] + [InlineData("http://[::1]:5000", "::1", 5000, false)] + [InlineData("http://[::1]", "::1", 80, false)] + [InlineData("http://127.0.0.1", "127.0.0.1", 80, false)] + [InlineData("https://127.0.0.1", "127.0.0.1", 443, true)] + public void ParseAddressIP(string address, string ip, int port, bool isHttps) + { + var options = new KestrelServerOptions(); + var listenOptions = AddressBinder.ParseAddress(address, out var https); + Assert.Equal(ListenType.IPEndPoint, listenOptions.Type); + Assert.Equal(IPAddress.Parse(ip), listenOptions.IPEndPoint.Address); + Assert.Equal(port, listenOptions.IPEndPoint.Port); + Assert.Equal(isHttps, https); + } + + [Fact] + public async Task WrapsAddressInUseExceptionAsIOException() + { + var addresses = new ServerAddressesFeature(); + addresses.Addresses.Add("http://localhost:5000"); + var options = new KestrelServerOptions(); + + await Assert.ThrowsAsync(() => + AddressBinder.BindAsync(addresses, + options, + NullLogger.Instance, + endpoint => throw new AddressInUseException("already in use"))); + } + + [Theory] + [InlineData("http://*:80")] + [InlineData("http://+:80")] + [InlineData("http://contoso.com:80")] + public async Task FallbackToIPv4WhenIPv6AnyBindFails(string address) + { + var logger = new MockLogger(); + var addresses = new ServerAddressesFeature(); + addresses.Addresses.Add(address); + var options = new KestrelServerOptions(); + + var ipV6Attempt = false; + var ipV4Attempt = false; + + await AddressBinder.BindAsync(addresses, + options, + logger, + endpoint => + { + if (endpoint.IPEndPoint.Address == IPAddress.IPv6Any) + { + ipV6Attempt = true; + throw new InvalidOperationException("EAFNOSUPPORT"); + } + + if (endpoint.IPEndPoint.Address == IPAddress.Any) + { + ipV4Attempt = true; + } + + return Task.CompletedTask; + }); + + Assert.True(ipV4Attempt, "Should have attempted to bind to IPAddress.Any"); + Assert.True(ipV6Attempt, "Should have attempted to bind to IPAddress.IPv6Any"); + Assert.Contains(logger.Messages, f => f.Equals(CoreStrings.FormatFallbackToIPv4Any(80))); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/AsciiDecoding.cs b/src/Servers/Kestrel/Core/test/AsciiDecoding.cs new file mode 100644 index 0000000000..7fa45513d2 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/AsciiDecoding.cs @@ -0,0 +1,76 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Numerics; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class AsciiDecodingTests + { + [Fact] + private void FullAsciiRangeSupported() + { + var byteRange = Enumerable.Range(1, 127).Select(x => (byte)x); + + var byteArray = byteRange + .Concat(byteRange) + .Concat(byteRange) + .Concat(byteRange) + .Concat(byteRange) + .Concat(byteRange) + .ToArray(); + + var s = new Span(byteArray).GetAsciiStringNonNullCharacters(); + + Assert.Equal(s.Length, byteArray.Length); + + for (var i = 1; i < byteArray.Length; i++) + { + var sb = (byte)s[i]; + var b = byteArray[i]; + + Assert.Equal(sb, b); + } + } + + [Theory] + [InlineData(0x00)] + [InlineData(0x80)] + private void ExceptionThrownForZeroOrNonAscii(byte b) + { + for (var length = 1; length < Vector.Count * 4; length++) + { + for (var position = 0; position < length; position++) + { + var byteRange = Enumerable.Range(1, length).Select(x => (byte)x).ToArray(); + byteRange[position] = b; + + Assert.Throws(() => new Span(byteRange).GetAsciiStringNonNullCharacters()); + } + } + } + + [Fact] + private void LargeAllocationProducesCorrectResults() + { + var byteRange = Enumerable.Range(0, 16384 + 64).Select(x => (byte)((x & 0x7f) | 0x01)).ToArray(); + var expectedByteRange = byteRange.Concat(byteRange).ToArray(); + + var s = new Span(expectedByteRange).GetAsciiStringNonNullCharacters(); + + Assert.Equal(expectedByteRange.Length, s.Length); + + for (var i = 0; i < expectedByteRange.Length; i++) + { + var sb = (byte)((s[i] & 0x7f) | 0x01); + var b = expectedByteRange[i]; + + Assert.Equal(sb, b); + } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/BufferReaderTests.cs b/src/Servers/Kestrel/Core/test/BufferReaderTests.cs new file mode 100644 index 0000000000..294f61cdb0 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/BufferReaderTests.cs @@ -0,0 +1,300 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Xunit; + +namespace System.Buffers.Tests +{ + public abstract class ReadableBufferReaderFacts + { + public class Array : SingleSegment + { + public Array() : base(ReadOnlySequenceFactory.ArrayFactory) { } + internal Array(ReadOnlySequenceFactory factory) : base(factory) { } + } + + public class OwnedMemory : SingleSegment + { + public OwnedMemory() : base(ReadOnlySequenceFactory.OwnedMemoryFactory) { } + } + public class Memory : SingleSegment + { + public Memory() : base(ReadOnlySequenceFactory.MemoryFactory) { } + } + + public class SingleSegment : SegmentPerByte + { + public SingleSegment() : base(ReadOnlySequenceFactory.SingleSegmentFactory) { } + internal SingleSegment(ReadOnlySequenceFactory factory) : base(factory) { } + + [Fact] + public void AdvanceSingleBufferSkipsBytes() + { + var reader = new BufferReader(Factory.CreateWithContent(new byte[] { 1, 2, 3, 4, 5 })); + reader.Advance(2); + Assert.Equal(2, reader.CurrentSegmentIndex); + Assert.Equal(3, reader.CurrentSegment[reader.CurrentSegmentIndex]); + Assert.Equal(3, reader.Peek()); + reader.Advance(2); + Assert.Equal(5, reader.Peek()); + Assert.Equal(4, reader.CurrentSegmentIndex); + Assert.Equal(5, reader.CurrentSegment[reader.CurrentSegmentIndex]); + } + + [Fact] + public void TakeReturnsByteAndMoves() + { + var reader = new BufferReader(Factory.CreateWithContent(new byte[] { 1, 2 })); + Assert.Equal(0, reader.CurrentSegmentIndex); + Assert.Equal(1, reader.CurrentSegment[reader.CurrentSegmentIndex]); + Assert.Equal(1, reader.Read()); + Assert.Equal(1, reader.CurrentSegmentIndex); + Assert.Equal(2, reader.CurrentSegment[reader.CurrentSegmentIndex]); + Assert.Equal(2, reader.Read()); + Assert.Equal(-1, reader.Read()); + } + } + + public class SegmentPerByte : ReadableBufferReaderFacts + { + public SegmentPerByte() : base(ReadOnlySequenceFactory.SegmentPerByteFactory) { } + internal SegmentPerByte(ReadOnlySequenceFactory factory) : base(factory) { } + } + + internal ReadOnlySequenceFactory Factory { get; } + + internal ReadableBufferReaderFacts(ReadOnlySequenceFactory factory) + { + Factory = factory; + } + + [Fact] + public void PeekReturnsByteWithoutMoving() + { + var reader = new BufferReader(Factory.CreateWithContent(new byte[] { 1, 2 })); + Assert.Equal(1, reader.Peek()); + Assert.Equal(1, reader.Peek()); + } + + [Fact] + public void CursorIsCorrectAtEnd() + { + var reader = new BufferReader(Factory.CreateWithContent(new byte[] { 1, 2 })); + reader.Read(); + reader.Read(); + Assert.True(reader.End); + } + + [Fact] + public void CursorIsCorrectWithEmptyLastBlock() + { + var first = new BufferSegment(new byte[] { 1, 2 }); + var last = first.Append(new byte[4]); + + var reader = new BufferReader(new ReadOnlySequence(first, 0, last, 0)); + reader.Read(); + reader.Read(); + reader.Read(); + Assert.Same(last, reader.Position.GetObject()); + Assert.Equal(0, reader.Position.GetInteger()); + Assert.True(reader.End); + } + + [Fact] + public void PeekReturnsMinuOneByteInTheEnd() + { + var reader = new BufferReader(Factory.CreateWithContent(new byte[] { 1, 2 })); + Assert.Equal(1, reader.Read()); + Assert.Equal(2, reader.Read()); + Assert.Equal(-1, reader.Peek()); + } + + [Fact] + public void AdvanceToEndThenPeekReturnsMinusOne() + { + var reader = new BufferReader(Factory.CreateWithContent(new byte[] { 1, 2, 3, 4, 5 })); + reader.Advance(5); + Assert.True(reader.End); + Assert.Equal(-1, reader.Peek()); + } + + [Fact] + public void AdvancingPastLengthThrows() + { + var reader = new BufferReader(Factory.CreateWithContent(new byte[] { 1, 2, 3, 4, 5 })); + try + { + reader.Advance(6); + Assert.True(false); + } + catch (Exception ex) + { + Assert.True(ex is ArgumentOutOfRangeException); + } + } + + [Fact] + public void CtorFindsFirstNonEmptySegment() + { + var buffer = Factory.CreateWithContent(new byte[] { 1 }); + var reader = new BufferReader(buffer); + + Assert.Equal(1, reader.Peek()); + } + + [Fact] + public void EmptySegmentsAreSkippedOnMoveNext() + { + var buffer = Factory.CreateWithContent(new byte[] { 1, 2 }); + var reader = new BufferReader(buffer); + + Assert.Equal(1, reader.Peek()); + reader.Advance(1); + Assert.Equal(2, reader.Peek()); + } + + [Fact] + public void PeekGoesToEndIfAllEmptySegments() + { + var buffer = Factory.CreateOfSize(0); + var reader = new BufferReader(buffer); + + Assert.Equal(-1, reader.Peek()); + Assert.True(reader.End); + } + + [Fact] + public void AdvanceTraversesSegments() + { + var buffer = Factory.CreateWithContent(new byte[] { 1, 2, 3 }); + var reader = new BufferReader(buffer); + + reader.Advance(2); + Assert.Equal(3, reader.CurrentSegment[reader.CurrentSegmentIndex]); + Assert.Equal(3, reader.Read()); + } + + [Fact] + public void AdvanceThrowsPastLengthMultipleSegments() + { + var buffer = Factory.CreateWithContent(new byte[] { 1, 2, 3 }); + var reader = new BufferReader(buffer); + + try + { + reader.Advance(4); + Assert.True(false); + } + catch (Exception ex) + { + Assert.True(ex is ArgumentOutOfRangeException); + } + } + + [Fact] + public void TakeTraversesSegments() + { + var buffer = Factory.CreateWithContent(new byte[] { 1, 2, 3 }); + var reader = new BufferReader(buffer); + + Assert.Equal(1, reader.Read()); + Assert.Equal(2, reader.Read()); + Assert.Equal(3, reader.Read()); + Assert.Equal(-1, reader.Read()); + } + + [Fact] + public void PeekTraversesSegments() + { + var buffer = Factory.CreateWithContent(new byte[] { 1, 2 }); + var reader = new BufferReader(buffer); + + Assert.Equal(1, reader.CurrentSegment[reader.CurrentSegmentIndex]); + Assert.Equal(1, reader.Read()); + + Assert.Equal(2, reader.CurrentSegment[reader.CurrentSegmentIndex]); + Assert.Equal(2, reader.Peek()); + Assert.Equal(2, reader.Read()); + Assert.Equal(-1, reader.Peek()); + Assert.Equal(-1, reader.Read()); + } + + [Fact] + public void PeekWorkesWithEmptySegments() + { + var buffer = Factory.CreateWithContent(new byte[] { 1 }); + var reader = new BufferReader(buffer); + + Assert.Equal(0, reader.CurrentSegmentIndex); + Assert.Equal(1, reader.CurrentSegment.Length); + Assert.Equal(1, reader.Peek()); + Assert.Equal(1, reader.Read()); + Assert.Equal(-1, reader.Peek()); + Assert.Equal(-1, reader.Read()); + } + + [Fact] + public void WorkesWithEmptyBuffer() + { + var reader = new BufferReader(Factory.CreateWithContent(new byte[] { })); + + Assert.Equal(0, reader.CurrentSegmentIndex); + Assert.Equal(0, reader.CurrentSegment.Length); + Assert.Equal(-1, reader.Peek()); + Assert.Equal(-1, reader.Read()); + } + + [Theory] + [InlineData(0, false)] + [InlineData(5, false)] + [InlineData(10, false)] + [InlineData(11, true)] + [InlineData(12, true)] + [InlineData(15, true)] + public void ReturnsCorrectCursor(int takes, bool end) + { + var readableBuffer = Factory.CreateWithContent(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }); + var reader = new BufferReader(readableBuffer); + for (int i = 0; i < takes; i++) + { + reader.Read(); + } + + var expected = end ? new byte[] { } : readableBuffer.Slice((long)takes).ToArray(); + Assert.Equal(expected, readableBuffer.Slice(reader.Position).ToArray()); + } + + [Fact] + public void SlicingBufferReturnsCorrectCursor() + { + var buffer = Factory.CreateWithContent(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }); + var sliced = buffer.Slice(2L); + + var reader = new BufferReader(sliced); + Assert.Equal(sliced.ToArray(), buffer.Slice(reader.Position).ToArray()); + Assert.Equal(2, reader.Peek()); + Assert.Equal(0, reader.CurrentSegmentIndex); + } + + [Fact] + public void ReaderIndexIsCorrect() + { + var buffer = Factory.CreateWithContent(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }); + var reader = new BufferReader(buffer); + + var counter = 1; + while (!reader.End) + { + var span = reader.CurrentSegment; + for (int i = reader.CurrentSegmentIndex; i < span.Length; i++) + { + Assert.Equal(counter++, reader.CurrentSegment[i]); + } + reader.Advance(span.Length); + } + Assert.Equal(buffer.Length, reader.ConsumedBytes); + } + } + +} diff --git a/src/Servers/Kestrel/Core/test/BufferWriterTests.cs b/src/Servers/Kestrel/Core/test/BufferWriterTests.cs new file mode 100644 index 0000000000..2060ccdb32 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/BufferWriterTests.cs @@ -0,0 +1,201 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using Xunit; + +namespace System.IO.Pipelines.Tests +{ + public class BufferWriterTests : IDisposable + { + protected Pipe Pipe; + public BufferWriterTests() + { + Pipe = new Pipe(new PipeOptions(useSynchronizationContext: false, pauseWriterThreshold: 0, resumeWriterThreshold: 0)); + } + + public void Dispose() + { + Pipe.Writer.Complete(); + Pipe.Reader.Complete(); + } + + private byte[] Read() + { + Pipe.Writer.FlushAsync().GetAwaiter().GetResult(); + Pipe.Writer.Complete(); + ReadResult readResult = Pipe.Reader.ReadAsync().GetAwaiter().GetResult(); + byte[] data = readResult.Buffer.ToArray(); + Pipe.Reader.AdvanceTo(readResult.Buffer.End); + return data; + } + + [Theory] + [InlineData(3, -1, 0)] + [InlineData(3, 0, -1)] + [InlineData(3, 0, 4)] + [InlineData(3, 4, 0)] + [InlineData(3, -1, -1)] + [InlineData(3, 4, 4)] + public void ThrowsForInvalidParameters(int arrayLength, int offset, int length) + { + BufferWriter writer = new BufferWriter(Pipe.Writer); + var array = new byte[arrayLength]; + for (var i = 0; i < array.Length; i++) + { + array[i] = (byte)(i + 1); + } + + writer.Write(new Span(array, 0, 0)); + writer.Write(new Span(array, array.Length, 0)); + + try + { + writer.Write(new Span(array, offset, length)); + Assert.True(false); + } + catch (Exception ex) + { + Assert.True(ex is ArgumentOutOfRangeException); + } + + writer.Write(new Span(array, 0, array.Length)); + writer.Commit(); + + Assert.Equal(array, Read()); + } + + [Theory] + [InlineData(0, 3)] + [InlineData(1, 2)] + [InlineData(2, 1)] + [InlineData(1, 1)] + public void CanWriteWithOffsetAndLength(int offset, int length) + { + BufferWriter writer = new BufferWriter(Pipe.Writer); + var array = new byte[] { 1, 2, 3 }; + + writer.Write(new Span(array, offset, length)); + + Assert.Equal(0, writer.BytesCommitted); + + writer.Commit(); + + Assert.Equal(length, writer.BytesCommitted); + Assert.Equal(array.Skip(offset).Take(length).ToArray(), Read()); + Assert.Equal(length, writer.BytesCommitted); + } + + [Fact] + public void CanWriteEmpty() + { + BufferWriter writer = new BufferWriter(Pipe.Writer); + var array = new byte[] { }; + + writer.Write(array); + writer.Write(new Span(array, 0, array.Length)); + writer.Commit(); + + Assert.Equal(0, writer.BytesCommitted); + Assert.Equal(array, Read()); + } + + [Fact] + public void CanWriteIntoHeadlessBuffer() + { + BufferWriter writer = new BufferWriter(Pipe.Writer); + + writer.Write(new byte[] { 1, 2, 3 }); + writer.Commit(); + + Assert.Equal(3, writer.BytesCommitted); + Assert.Equal(new byte[] { 1, 2, 3 }, Read()); + } + + [Fact] + public void CanWriteMultipleTimes() + { + BufferWriter writer = new BufferWriter(Pipe.Writer); + + writer.Write(new byte[] { 1 }); + writer.Write(new byte[] { 2 }); + writer.Write(new byte[] { 3 }); + writer.Commit(); + + Assert.Equal(3, writer.BytesCommitted); + Assert.Equal(new byte[] { 1, 2, 3 }, Read()); + } + + [Fact] + public void CanWriteOverTheBlockLength() + { + Memory memory = Pipe.Writer.GetMemory(); + BufferWriter writer = new BufferWriter(Pipe.Writer); + + IEnumerable source = Enumerable.Range(0, memory.Length).Select(i => (byte)i); + byte[] expectedBytes = source.Concat(source).Concat(source).ToArray(); + + writer.Write(expectedBytes); + writer.Commit(); + + Assert.Equal(expectedBytes.LongLength, writer.BytesCommitted); + Assert.Equal(expectedBytes, Read()); + } + + [Fact] + public void EnsureAllocatesSpan() + { + BufferWriter writer = new BufferWriter(Pipe.Writer); + writer.Ensure(10); + Assert.True(writer.Span.Length > 10); + Assert.Equal(0, writer.BytesCommitted); + Assert.Equal(new byte[] { }, Read()); + } + + [Fact] + public void ExposesSpan() + { + int initialLength = Pipe.Writer.GetMemory().Length; + BufferWriter writer = new BufferWriter(Pipe.Writer); + Assert.Equal(initialLength, writer.Span.Length); + Assert.Equal(new byte[] { }, Read()); + } + + [Fact] + public void SlicesSpanAndAdvancesAfterWrite() + { + int initialLength = Pipe.Writer.GetMemory().Length; + + BufferWriter writer = new BufferWriter(Pipe.Writer); + + writer.Write(new byte[] { 1, 2, 3 }); + writer.Commit(); + + Assert.Equal(3, writer.BytesCommitted); + Assert.Equal(initialLength - 3, writer.Span.Length); + Assert.Equal(Pipe.Writer.GetMemory().Length, writer.Span.Length); + Assert.Equal(new byte[] { 1, 2, 3 }, Read()); + } + + [Theory] + [InlineData(5)] + [InlineData(50)] + [InlineData(500)] + [InlineData(5000)] + [InlineData(50000)] + public void WriteLargeDataBinary(int length) + { + var data = new byte[length]; + new Random(length).NextBytes(data); + + BufferWriter writer = new BufferWriter(Pipe.Writer); + writer.Write(data); + writer.Commit(); + + Assert.Equal(length, writer.BytesCommitted); + Assert.Equal(data, Read()); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/ChunkWriterTests.cs b/src/Servers/Kestrel/Core/test/ChunkWriterTests.cs new file mode 100644 index 0000000000..722c0281ab --- /dev/null +++ b/src/Servers/Kestrel/Core/test/ChunkWriterTests.cs @@ -0,0 +1,38 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Linq; +using System.Text; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class ChunkWriterTests + { + [Theory] + [InlineData(1, "1\r\n")] + [InlineData(10, "a\r\n")] + [InlineData(0x08, "8\r\n")] + [InlineData(0x10, "10\r\n")] + [InlineData(0x080, "80\r\n")] + [InlineData(0x100, "100\r\n")] + [InlineData(0x0800, "800\r\n")] + [InlineData(0x1000, "1000\r\n")] + [InlineData(0x08000, "8000\r\n")] + [InlineData(0x10000, "10000\r\n")] + [InlineData(0x080000, "80000\r\n")] + [InlineData(0x100000, "100000\r\n")] + [InlineData(0x0800000, "800000\r\n")] + [InlineData(0x1000000, "1000000\r\n")] + [InlineData(0x08000000, "8000000\r\n")] + [InlineData(0x10000000, "10000000\r\n")] + [InlineData(0x7fffffffL, "7fffffff\r\n")] + public void ChunkedPrefixMustBeHexCrLfWithoutLeadingZeros(int dataCount, string expected) + { + var beginChunkBytes = ChunkWriter.BeginChunkBytes(dataCount); + + Assert.Equal(Encoding.ASCII.GetBytes(expected), beginChunkBytes.ToArray()); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/ConnectionDispatcherTests.cs b/src/Servers/Kestrel/Core/test/ConnectionDispatcherTests.cs new file mode 100644 index 0000000000..c9acd1c658 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/ConnectionDispatcherTests.cs @@ -0,0 +1,46 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class ConnectionDispatcherTests + { + [Fact] + public void OnConnectionCreatesLogScopeWithConnectionId() + { + var serviceContext = new TestServiceContext(); + var tcs = new TaskCompletionSource(); + var dispatcher = new ConnectionDispatcher(serviceContext, _ => tcs.Task); + + var connection = Mock.Of(); + + dispatcher.OnConnection(connection); + + // The scope should be created + var scopeObjects = ((TestKestrelTrace)serviceContext.Log) + .Logger + .Scopes + .OfType>>() + .ToList(); + + Assert.Single(scopeObjects); + var pairs = scopeObjects[0].ToDictionary(p => p.Key, p => p.Value); + Assert.True(pairs.ContainsKey("ConnectionId")); + Assert.Equal(connection.ConnectionId, pairs["ConnectionId"]); + + tcs.TrySetResult(null); + + // Verify the scope was disposed after request processing completed + Assert.True(((TestKestrelTrace)serviceContext.Log).Logger.Scopes.IsEmpty); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/DateHeaderValueManagerTests.cs b/src/Servers/Kestrel/Core/test/DateHeaderValueManagerTests.cs new file mode 100644 index 0000000000..3d1e880765 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/DateHeaderValueManagerTests.cs @@ -0,0 +1,113 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class DateHeaderValueManagerTests + { + /// + /// DateTime format string for RFC1123. + /// + /// + /// See https://msdn.microsoft.com/en-us/library/az4se3k1(v=vs.110).aspx#RFC1123 for info on the format. + /// + private const string Rfc1123DateFormat = "r"; + + [Fact] + public void GetDateHeaderValue_ReturnsDateValueInRFC1123Format() + { + var now = DateTimeOffset.UtcNow; + var systemClock = new MockSystemClock + { + UtcNow = now + }; + + var dateHeaderValueManager = new DateHeaderValueManager(systemClock); + Assert.Equal(now.ToString(Rfc1123DateFormat), dateHeaderValueManager.GetDateHeaderValues().String); + } + + [Fact] + public void GetDateHeaderValue_ReturnsCachedValueBetweenTimerTicks() + { + var now = DateTimeOffset.UtcNow; + var future = now.AddSeconds(10); + var systemClock = new MockSystemClock + { + UtcNow = now + }; + + var dateHeaderValueManager = new DateHeaderValueManager(systemClock); + var testKestrelTrace = new TestKestrelTrace(); + + using (var heartbeat = new Heartbeat(new IHeartbeatHandler[] { dateHeaderValueManager }, systemClock, DebuggerWrapper.Singleton, testKestrelTrace)) + { + Assert.Equal(now.ToString(Rfc1123DateFormat), dateHeaderValueManager.GetDateHeaderValues().String); + systemClock.UtcNow = future; + Assert.Equal(now.ToString(Rfc1123DateFormat), dateHeaderValueManager.GetDateHeaderValues().String); + } + + Assert.Equal(1, systemClock.UtcNowCalled); + } + + [Fact] + public void GetDateHeaderValue_ReturnsUpdatedValueAfterHeartbeat() + { + var now = DateTimeOffset.UtcNow; + var future = now.AddSeconds(10); + var systemClock = new MockSystemClock + { + UtcNow = now + }; + + var dateHeaderValueManager = new DateHeaderValueManager(systemClock); + var testKestrelTrace = new TestKestrelTrace(); + + var mockHeartbeatHandler = new Mock(); + + using (var heartbeat = new Heartbeat(new[] { dateHeaderValueManager, mockHeartbeatHandler.Object }, systemClock, DebuggerWrapper.Singleton, testKestrelTrace)) + { + heartbeat.OnHeartbeat(); + + Assert.Equal(now.ToString(Rfc1123DateFormat), dateHeaderValueManager.GetDateHeaderValues().String); + + // Wait for the next heartbeat before verifying GetDateHeaderValues picks up new time. + systemClock.UtcNow = future; + + heartbeat.OnHeartbeat(); + + Assert.Equal(future.ToString(Rfc1123DateFormat), dateHeaderValueManager.GetDateHeaderValues().String); + Assert.True(systemClock.UtcNowCalled >= 2); + } + } + + [Fact] + public void GetDateHeaderValue_ReturnsLastDateValueAfterHeartbeatDisposed() + { + var now = DateTimeOffset.UtcNow; + var future = now.AddSeconds(10); + var systemClock = new MockSystemClock + { + UtcNow = now + }; + + var dateHeaderValueManager = new DateHeaderValueManager(systemClock); + var testKestrelTrace = new TestKestrelTrace(); + + using (var heatbeat = new Heartbeat(new IHeartbeatHandler[] { dateHeaderValueManager }, systemClock, DebuggerWrapper.Singleton, testKestrelTrace)) + { + heatbeat.OnHeartbeat(); + Assert.Equal(now.ToString(Rfc1123DateFormat), dateHeaderValueManager.GetDateHeaderValues().String); + } + + systemClock.UtcNow = future; + Assert.Equal(now.ToString(Rfc1123DateFormat), dateHeaderValueManager.GetDateHeaderValues().String); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/DynamicTableTests.cs b/src/Servers/Kestrel/Core/test/DynamicTableTests.cs new file mode 100644 index 0000000000..0943272c72 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/DynamicTableTests.cs @@ -0,0 +1,158 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Text; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class DynamicTableTests + { + private readonly HeaderField _header1 = new HeaderField(Encoding.ASCII.GetBytes("header-1"), Encoding.ASCII.GetBytes("value1")); + private readonly HeaderField _header2 = new HeaderField(Encoding.ASCII.GetBytes("header-02"), Encoding.ASCII.GetBytes("value_2")); + + [Fact] + public void DynamicTableIsInitiallyEmpty() + { + var dynamicTable = new DynamicTable(4096); + Assert.Equal(0, dynamicTable.Count); + Assert.Equal(0, dynamicTable.Size); + Assert.Equal(4096, dynamicTable.MaxSize); + } + + [Fact] + public void CountIsNumberOfEntriesInDynamicTable() + { + var dynamicTable = new DynamicTable(4096); + + dynamicTable.Insert(_header1.Name, _header1.Value); + Assert.Equal(1, dynamicTable.Count); + + dynamicTable.Insert(_header2.Name, _header2.Value); + Assert.Equal(2, dynamicTable.Count); + } + + [Fact] + public void SizeIsCurrentDynamicTableSize() + { + var dynamicTable = new DynamicTable(4096); + Assert.Equal(0, dynamicTable.Size); + + dynamicTable.Insert(_header1.Name, _header1.Value); + Assert.Equal(_header1.Length, dynamicTable.Size); + + dynamicTable.Insert(_header2.Name, _header2.Value); + Assert.Equal(_header1.Length + _header2.Length, dynamicTable.Size); + } + + [Fact] + public void FirstEntryIsMostRecentEntry() + { + var dynamicTable = new DynamicTable(4096); + dynamicTable.Insert(_header1.Name, _header1.Value); + dynamicTable.Insert(_header2.Name, _header2.Value); + + VerifyTableEntries(dynamicTable, _header2, _header1); + } + + [Fact] + public void ThrowsIndexOutOfRangeException() + { + var dynamicTable = new DynamicTable(4096); + Assert.Throws(() => dynamicTable[0]); + + dynamicTable.Insert(_header1.Name, _header1.Value); + Assert.Throws(() => dynamicTable[1]); + } + + [Fact] + public void NoOpWhenInsertingEntryLargerThanMaxSize() + { + var dynamicTable = new DynamicTable(_header1.Length - 1); + dynamicTable.Insert(_header1.Name, _header1.Value); + + Assert.Equal(0, dynamicTable.Count); + Assert.Equal(0, dynamicTable.Size); + } + + [Fact] + public void NoOpWhenInsertingEntryLargerThanRemainingSpace() + { + var dynamicTable = new DynamicTable(_header1.Length); + dynamicTable.Insert(_header1.Name, _header1.Value); + + VerifyTableEntries(dynamicTable, _header1); + + dynamicTable.Insert(_header2.Name, _header2.Value); + + Assert.Equal(0, dynamicTable.Count); + Assert.Equal(0, dynamicTable.Size); + } + + [Fact] + public void ResizingEvictsOldestEntries() + { + var dynamicTable = new DynamicTable(4096); + dynamicTable.Insert(_header1.Name, _header1.Value); + dynamicTable.Insert(_header2.Name, _header2.Value); + + VerifyTableEntries(dynamicTable, _header2, _header1); + + dynamicTable.Resize(_header2.Length); + + VerifyTableEntries(dynamicTable, _header2); + } + + [Fact] + public void ResizingToZeroEvictsAllEntries() + { + var dynamicTable = new DynamicTable(4096); + dynamicTable.Insert(_header1.Name, _header1.Value); + dynamicTable.Insert(_header2.Name, _header2.Value); + + dynamicTable.Resize(0); + + Assert.Equal(0, dynamicTable.Count); + Assert.Equal(0, dynamicTable.Size); + } + + [Fact] + public void CanBeResizedToLargerMaxSize() + { + var dynamicTable = new DynamicTable(_header1.Length); + dynamicTable.Insert(_header1.Name, _header1.Value); + dynamicTable.Insert(_header2.Name, _header2.Value); + + // _header2 is larger than _header1, so an attempt at inserting it + // would first clear the table then return without actually inserting it, + // given it is larger than the current max size. + Assert.Equal(0, dynamicTable.Count); + Assert.Equal(0, dynamicTable.Size); + + dynamicTable.Resize(dynamicTable.MaxSize + _header2.Length); + dynamicTable.Insert(_header2.Name, _header2.Value); + + VerifyTableEntries(dynamicTable, _header2); + } + + private void VerifyTableEntries(DynamicTable dynamicTable, params HeaderField[] entries) + { + Assert.Equal(entries.Length, dynamicTable.Count); + Assert.Equal(entries.Sum(e => e.Length), dynamicTable.Size); + + for (var i = 0; i < entries.Length; i++) + { + var headerField = dynamicTable[i]; + + Assert.NotSame(entries[i].Name, headerField.Name); + Assert.Equal(entries[i].Name, headerField.Name); + + Assert.NotSame(entries[i].Value, headerField.Value); + Assert.Equal(entries[i].Value, headerField.Value); + } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HPackDecoderTests.cs b/src/Servers/Kestrel/Core/test/HPackDecoderTests.cs new file mode 100644 index 0000000000..a20c0be120 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HPackDecoderTests.cs @@ -0,0 +1,560 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HPackDecoderTests : IHttpHeadersHandler + { + private const int DynamicTableInitialMaxSize = 4096; + + // Indexed Header Field Representation - Static Table - Index 2 (:method: GET) + private static readonly byte[] _indexedHeaderStatic = new byte[] { 0x82 }; + + // Indexed Header Field Representation - Dynamic Table - Index 62 (first index in dynamic table) + private static readonly byte[] _indexedHeaderDynamic = new byte[] { 0xbe }; + + // Literal Header Field with Incremental Indexing Representation - New Name + private static readonly byte[] _literalHeaderFieldWithIndexingNewName = new byte[] { 0x40 }; + + // Literal Header Field with Incremental Indexing Representation - Indexed Name - Index 58 (user-agent) + private static readonly byte[] _literalHeaderFieldWithIndexingIndexedName = new byte[] { 0x7a }; + + // Literal Header Field without Indexing Representation - New Name + private static readonly byte[] _literalHeaderFieldWithoutIndexingNewName = new byte[] { 0x00 }; + + // Literal Header Field without Indexing Representation - Indexed Name - Index 58 (user-agent) + private static readonly byte[] _literalHeaderFieldWithoutIndexingIndexedName = new byte[] { 0x0f, 0x2b }; + + // Literal Header Field Never Indexed Representation - New Name + private static readonly byte[] _literalHeaderFieldNeverIndexedNewName = new byte[] { 0x10 }; + + // Literal Header Field Never Indexed Representation - Indexed Name - Index 58 (user-agent) + private static readonly byte[] _literalHeaderFieldNeverIndexedIndexedName = new byte[] { 0x1f, 0x2b }; + + private const string _userAgentString = "user-agent"; + + private static readonly byte[] _userAgentBytes = Encoding.ASCII.GetBytes(_userAgentString); + + private const string _headerNameString = "new-header"; + + private static readonly byte[] _headerNameBytes = Encoding.ASCII.GetBytes(_headerNameString); + + // n e w - h e a d e r * + // 10101000 10111110 00010110 10011100 10100011 10010000 10110110 01111111 + private static readonly byte[] _headerNameHuffmanBytes = new byte[] { 0xa8, 0xbe, 0x16, 0x9c, 0xa3, 0x90, 0xb6, 0x7f }; + + private const string _headerValueString = "value"; + + private static readonly byte[] _headerValueBytes = Encoding.ASCII.GetBytes(_headerValueString); + + // v a l u e * + // 11101110 00111010 00101101 00101111 + private static readonly byte[] _headerValueHuffmanBytes = new byte [] { 0xee, 0x3a, 0x2d, 0x2f }; + + private static readonly byte[] _headerName = new byte[] { (byte)_headerNameBytes.Length } + .Concat(_headerNameBytes) + .ToArray(); + + private static readonly byte[] _headerNameHuffman = new byte[] { (byte)(0x80 | _headerNameHuffmanBytes.Length) } + .Concat(_headerNameHuffmanBytes) + .ToArray(); + + private static readonly byte[] _headerValue = new byte[] { (byte)_headerValueBytes.Length } + .Concat(_headerValueBytes) + .ToArray(); + + private static readonly byte[] _headerValueHuffman = new byte[] { (byte)(0x80 | _headerValueHuffmanBytes.Length) } + .Concat(_headerValueHuffmanBytes) + .ToArray(); + + // & * + // 11111000 11111111 + private static readonly byte[] _huffmanLongPadding = new byte[] { 0x82, 0xf8, 0xff }; + + // EOS * + // 11111111 11111111 11111111 11111111 + private static readonly byte[] _huffmanEos = new byte[] { 0x84, 0xff, 0xff, 0xff, 0xff }; + + private readonly DynamicTable _dynamicTable; + private readonly HPackDecoder _decoder; + + private readonly Dictionary _decodedHeaders = new Dictionary(); + + public HPackDecoderTests() + { + _dynamicTable = new DynamicTable(DynamicTableInitialMaxSize); + _decoder = new HPackDecoder(DynamicTableInitialMaxSize, _dynamicTable); + } + + void IHttpHeadersHandler.OnHeader(Span name, Span value) + { + _decodedHeaders[name.GetAsciiStringNonNullCharacters()] = value.GetAsciiStringNonNullCharacters(); + } + + [Fact] + public void DecodesIndexedHeaderField_StaticTable() + { + _decoder.Decode(_indexedHeaderStatic, endHeaders: true, handler: this); + Assert.Equal("GET", _decodedHeaders[":method"]); + } + + [Fact] + public void DecodesIndexedHeaderField_DynamicTable() + { + // Add the header to the dynamic table + _dynamicTable.Insert(_headerNameBytes, _headerValueBytes); + + // Index it + _decoder.Decode(_indexedHeaderDynamic, endHeaders: true, handler: this); + Assert.Equal(_headerValueString, _decodedHeaders[_headerNameString]); + } + + [Fact] + public void DecodesIndexedHeaderField_OutOfRange_Error() + { + var exception = Assert.Throws(() => _decoder.Decode(_indexedHeaderDynamic, endHeaders: true, handler: this)); + Assert.Equal(CoreStrings.FormatHPackErrorIndexOutOfRange(62), exception.Message); + Assert.Empty(_decodedHeaders); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithIncrementalIndexing_NewName() + { + var encoded = _literalHeaderFieldWithIndexingNewName + .Concat(_headerName) + .Concat(_headerValue) + .ToArray(); + + TestDecodeWithIndexing(encoded, _headerNameString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithIncrementalIndexing_NewName_HuffmanEncodedName() + { + var encoded = _literalHeaderFieldWithIndexingNewName + .Concat(_headerNameHuffman) + .Concat(_headerValue) + .ToArray(); + + TestDecodeWithIndexing(encoded, _headerNameString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithIncrementalIndexing_NewName_HuffmanEncodedValue() + { + var encoded = _literalHeaderFieldWithIndexingNewName + .Concat(_headerName) + .Concat(_headerValueHuffman) + .ToArray(); + + TestDecodeWithIndexing(encoded, _headerNameString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithIncrementalIndexing_NewName_HuffmanEncodedNameAndValue() + { + var encoded = _literalHeaderFieldWithIndexingNewName + .Concat(_headerNameHuffman) + .Concat(_headerValueHuffman) + .ToArray(); + + TestDecodeWithIndexing(encoded, _headerNameString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithIncrementalIndexing_IndexedName() + { + var encoded = _literalHeaderFieldWithIndexingIndexedName + .Concat(_headerValue) + .ToArray(); + + TestDecodeWithIndexing(encoded, _userAgentString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithIncrementalIndexing_IndexedName_HuffmanEncodedValue() + { + var encoded = _literalHeaderFieldWithIndexingIndexedName + .Concat(_headerValueHuffman) + .ToArray(); + + TestDecodeWithIndexing(encoded, _userAgentString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithIncrementalIndexing_IndexedName_OutOfRange_Error() + { + // 01 (Literal Header Field without Indexing Representation) + // 11 1110 (Indexed Name - Index 62 encoded with 6-bit prefix - see http://httpwg.org/specs/rfc7541.html#integer.representation) + // Index 62 is the first entry in the dynamic table. If there's nothing there, the decoder should throw. + + var exception = Assert.Throws(() => _decoder.Decode(new byte[] { 0x7e }, endHeaders: true, handler: this)); + Assert.Equal(CoreStrings.FormatHPackErrorIndexOutOfRange(62), exception.Message); + Assert.Empty(_decodedHeaders); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithoutIndexing_NewName() + { + var encoded = _literalHeaderFieldWithoutIndexingNewName + .Concat(_headerName) + .Concat(_headerValue) + .ToArray(); + + TestDecodeWithoutIndexing(encoded, _headerNameString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithoutIndexing_NewName_HuffmanEncodedName() + { + var encoded = _literalHeaderFieldWithoutIndexingNewName + .Concat(_headerNameHuffman) + .Concat(_headerValue) + .ToArray(); + + TestDecodeWithoutIndexing(encoded, _headerNameString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithoutIndexing_NewName_HuffmanEncodedValue() + { + var encoded = _literalHeaderFieldWithoutIndexingNewName + .Concat(_headerName) + .Concat(_headerValueHuffman) + .ToArray(); + + TestDecodeWithoutIndexing(encoded, _headerNameString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithoutIndexing_NewName_HuffmanEncodedNameAndValue() + { + var encoded = _literalHeaderFieldWithoutIndexingNewName + .Concat(_headerNameHuffman) + .Concat(_headerValueHuffman) + .ToArray(); + + TestDecodeWithoutIndexing(encoded, _headerNameString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithoutIndexing_IndexedName() + { + var encoded = _literalHeaderFieldWithoutIndexingIndexedName + .Concat(_headerValue) + .ToArray(); + + TestDecodeWithoutIndexing(encoded, _userAgentString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithoutIndexing_IndexedName_HuffmanEncodedValue() + { + var encoded = _literalHeaderFieldWithoutIndexingIndexedName + .Concat(_headerValueHuffman) + .ToArray(); + + TestDecodeWithoutIndexing(encoded, _userAgentString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldWithoutIndexing_IndexedName_OutOfRange_Error() + { + // 0000 (Literal Header Field without Indexing Representation) + // 1111 0010 1111 (Indexed Name - Index 62 encoded with 4-bit prefix - see http://httpwg.org/specs/rfc7541.html#integer.representation) + // Index 62 is the first entry in the dynamic table. If there's nothing there, the decoder should throw. + + var exception = Assert.Throws(() => _decoder.Decode(new byte[] { 0x0f, 0x2f }, endHeaders: true, handler: this)); + Assert.Equal(CoreStrings.FormatHPackErrorIndexOutOfRange(62), exception.Message); + Assert.Empty(_decodedHeaders); + } + + [Fact] + public void DecodesLiteralHeaderFieldNeverIndexed_NewName() + { + var encoded = _literalHeaderFieldNeverIndexedNewName + .Concat(_headerName) + .Concat(_headerValue) + .ToArray(); + + TestDecodeWithoutIndexing(encoded, _headerNameString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldNeverIndexed_NewName_HuffmanEncodedName() + { + var encoded = _literalHeaderFieldNeverIndexedNewName + .Concat(_headerNameHuffman) + .Concat(_headerValue) + .ToArray(); + + TestDecodeWithoutIndexing(encoded, _headerNameString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldNeverIndexed_NewName_HuffmanEncodedValue() + { + var encoded = _literalHeaderFieldNeverIndexedNewName + .Concat(_headerName) + .Concat(_headerValueHuffman) + .ToArray(); + + TestDecodeWithoutIndexing(encoded, _headerNameString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldNeverIndexed_NewName_HuffmanEncodedNameAndValue() + { + var encoded = _literalHeaderFieldNeverIndexedNewName + .Concat(_headerNameHuffman) + .Concat(_headerValueHuffman) + .ToArray(); + + TestDecodeWithoutIndexing(encoded, _headerNameString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldNeverIndexed_IndexedName() + { + // 0001 (Literal Header Field Never Indexed Representation) + // 1111 0010 1011 (Indexed Name - Index 58 encoded with 4-bit prefix - see http://httpwg.org/specs/rfc7541.html#integer.representation) + // Concatenated with value bytes + var encoded = _literalHeaderFieldNeverIndexedIndexedName + .Concat(_headerValue) + .ToArray(); + + TestDecodeWithoutIndexing(encoded, _userAgentString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldNeverIndexed_IndexedName_HuffmanEncodedValue() + { + // 0001 (Literal Header Field Never Indexed Representation) + // 1111 0010 1011 (Indexed Name - Index 58 encoded with 4-bit prefix - see http://httpwg.org/specs/rfc7541.html#integer.representation) + // Concatenated with Huffman encoded value bytes + var encoded = _literalHeaderFieldNeverIndexedIndexedName + .Concat(_headerValueHuffman) + .ToArray(); + + TestDecodeWithoutIndexing(encoded, _userAgentString, _headerValueString); + } + + [Fact] + public void DecodesLiteralHeaderFieldNeverIndexed_IndexedName_OutOfRange_Error() + { + // 0001 (Literal Header Field Never Indexed Representation) + // 1111 0010 1111 (Indexed Name - Index 62 encoded with 4-bit prefix - see http://httpwg.org/specs/rfc7541.html#integer.representation) + // Index 62 is the first entry in the dynamic table. If there's nothing there, the decoder should throw. + + var exception = Assert.Throws(() => _decoder.Decode(new byte[] { 0x1f, 0x2f }, endHeaders: true, handler: this)); + Assert.Equal(CoreStrings.FormatHPackErrorIndexOutOfRange(62), exception.Message); + Assert.Empty(_decodedHeaders); + } + + [Fact] + public void DecodesDynamicTableSizeUpdate() + { + // 001 (Dynamic Table Size Update) + // 11110 (30 encoded with 5-bit prefix - see http://httpwg.org/specs/rfc7541.html#integer.representation) + + Assert.Equal(DynamicTableInitialMaxSize, _dynamicTable.MaxSize); + + _decoder.Decode(new byte[] { 0x3e }, endHeaders: true, handler: this); + + Assert.Equal(30, _dynamicTable.MaxSize); + Assert.Empty(_decodedHeaders); + } + + [Fact] + public void DecodesDynamicTableSizeUpdate_GreaterThanLimit_Error() + { + // 001 (Dynamic Table Size Update) + // 11111 11100010 00011111 (4097 encoded with 5-bit prefix - see http://httpwg.org/specs/rfc7541.html#integer.representation) + + Assert.Equal(DynamicTableInitialMaxSize, _dynamicTable.MaxSize); + + var exception = Assert.Throws(() => _decoder.Decode(new byte[] { 0x3f, 0xe2, 0x1f }, endHeaders: true, handler: this)); + Assert.Equal(CoreStrings.FormatHPackErrorDynamicTableSizeUpdateTooLarge(4097, DynamicTableInitialMaxSize), exception.Message); + Assert.Empty(_decodedHeaders); + } + + [Fact] + public void DecodesStringLength_GreaterThanLimit_Error() + { + var encoded = _literalHeaderFieldWithoutIndexingNewName + .Concat(new byte[] { 0xff, 0x82, 0x1f }) // 4097 encoded with 7-bit prefix + .ToArray(); + + var exception = Assert.Throws(() => _decoder.Decode(encoded, endHeaders: true, handler: this)); + Assert.Equal(CoreStrings.FormatHPackStringLengthTooLarge(4097, HPackDecoder.MaxStringOctets), exception.Message); + Assert.Empty(_decodedHeaders); + } + + public static readonly TheoryData _incompleteHeaderBlockData = new TheoryData + { + // Indexed Header Field Representation - incomplete index encoding + new byte[] { 0xff }, + + // Literal Header Field with Incremental Indexing Representation - New Name - incomplete header name length encoding + new byte[] { 0x40, 0x7f }, + + // Literal Header Field with Incremental Indexing Representation - New Name - incomplete header name + new byte[] { 0x40, 0x01 }, + new byte[] { 0x40, 0x02, 0x61 }, + + // Literal Header Field with Incremental Indexing Representation - New Name - incomplete header value length encoding + new byte[] { 0x40, 0x01, 0x61, 0x7f }, + + // Literal Header Field with Incremental Indexing Representation - New Name - incomplete header value + new byte[] { 0x40, 0x01, 0x61, 0x01 }, + new byte[] { 0x40, 0x01, 0x61, 0x02, 0x61 }, + + // Literal Header Field with Incremental Indexing Representation - Indexed Name - incomplete index encoding + new byte[] { 0x7f }, + + // Literal Header Field with Incremental Indexing Representation - Indexed Name - incomplete header value length encoding + new byte[] { 0x7a, 0xff }, + + // Literal Header Field with Incremental Indexing Representation - Indexed Name - incomplete header value + new byte[] { 0x7a, 0x01 }, + new byte[] { 0x7a, 0x02, 0x61 }, + + // Literal Header Field without Indexing - New Name - incomplete header name length encoding + new byte[] { 0x00, 0xff }, + + // Literal Header Field without Indexing - New Name - incomplete header name + new byte[] { 0x00, 0x01 }, + new byte[] { 0x00, 0x02, 0x61 }, + + // Literal Header Field without Indexing - New Name - incomplete header value length encoding + new byte[] { 0x00, 0x01, 0x61, 0xff }, + + // Literal Header Field without Indexing - New Name - incomplete header value + new byte[] { 0x00, 0x01, 0x61, 0x01 }, + new byte[] { 0x00, 0x01, 0x61, 0x02, 0x61 }, + + // Literal Header Field without Indexing Representation - Indexed Name - incomplete index encoding + new byte[] { 0x0f }, + + // Literal Header Field without Indexing Representation - Indexed Name - incomplete header value length encoding + new byte[] { 0x02, 0xff }, + + // Literal Header Field without Indexing Representation - Indexed Name - incomplete header value + new byte[] { 0x02, 0x01 }, + new byte[] { 0x02, 0x02, 0x61 }, + + // Literal Header Field Never Indexed - New Name - incomplete header name length encoding + new byte[] { 0x10, 0xff }, + + // Literal Header Field Never Indexed - New Name - incomplete header name + new byte[] { 0x10, 0x01 }, + new byte[] { 0x10, 0x02, 0x61 }, + + // Literal Header Field Never Indexed - New Name - incomplete header value length encoding + new byte[] { 0x10, 0x01, 0x61, 0xff }, + + // Literal Header Field Never Indexed - New Name - incomplete header value + new byte[] { 0x10, 0x01, 0x61, 0x01 }, + new byte[] { 0x10, 0x01, 0x61, 0x02, 0x61 }, + + // Literal Header Field Never Indexed Representation - Indexed Name - incomplete index encoding + new byte[] { 0x1f }, + + // Literal Header Field Never Indexed Representation - Indexed Name - incomplete header value length encoding + new byte[] { 0x12, 0xff }, + + // Literal Header Field Never Indexed Representation - Indexed Name - incomplete header value + new byte[] { 0x12, 0x01 }, + new byte[] { 0x12, 0x02, 0x61 }, + + // Dynamic Table Size Update - incomplete max size encoding + new byte[] { 0x3f } + }; + + [Theory] + [MemberData(nameof(_incompleteHeaderBlockData))] + public void DecodesIncompleteHeaderBlock_Error(byte[] encoded) + { + var exception = Assert.Throws(() => _decoder.Decode(encoded, endHeaders: true, handler: this)); + Assert.Equal(CoreStrings.HPackErrorIncompleteHeaderBlock, exception.Message); + Assert.Empty(_decodedHeaders); + } + + public static readonly TheoryData _huffmanDecodingErrorData = new TheoryData + { + // Invalid Huffman encoding in header name + + _literalHeaderFieldWithIndexingNewName.Concat(_huffmanLongPadding).ToArray(), + _literalHeaderFieldWithIndexingNewName.Concat(_huffmanEos).ToArray(), + + _literalHeaderFieldWithoutIndexingNewName.Concat(_huffmanLongPadding).ToArray(), + _literalHeaderFieldWithoutIndexingNewName.Concat(_huffmanEos).ToArray(), + + _literalHeaderFieldNeverIndexedNewName.Concat(_huffmanLongPadding).ToArray(), + _literalHeaderFieldNeverIndexedNewName.Concat(_huffmanEos).ToArray(), + + // Invalid Huffman encoding in header value + + _literalHeaderFieldWithIndexingIndexedName.Concat(_huffmanLongPadding).ToArray(), + _literalHeaderFieldWithIndexingIndexedName.Concat(_huffmanEos).ToArray(), + + _literalHeaderFieldWithoutIndexingIndexedName.Concat(_huffmanLongPadding).ToArray(), + _literalHeaderFieldWithoutIndexingIndexedName.Concat(_huffmanEos).ToArray(), + + _literalHeaderFieldNeverIndexedIndexedName.Concat(_huffmanLongPadding).ToArray(), + _literalHeaderFieldNeverIndexedIndexedName.Concat(_huffmanEos).ToArray() + }; + + [Theory] + [MemberData(nameof(_huffmanDecodingErrorData))] + public void WrapsHuffmanDecodingExceptionInHPackDecodingException(byte[] encoded) + { + var exception = Assert.Throws(() => _decoder.Decode(encoded, endHeaders: true, handler: this)); + Assert.Equal(CoreStrings.HPackHuffmanError, exception.Message); + Assert.IsType(exception.InnerException); + Assert.Empty(_decodedHeaders); + } + + private void TestDecodeWithIndexing(byte[] encoded, string expectedHeaderName, string expectedHeaderValue) + { + TestDecode(encoded, expectedHeaderName, expectedHeaderValue, expectDynamicTableEntry: true); + } + + private void TestDecodeWithoutIndexing(byte[] encoded, string expectedHeaderName, string expectedHeaderValue) + { + TestDecode(encoded, expectedHeaderName, expectedHeaderValue, expectDynamicTableEntry: false); + } + + private void TestDecode(byte[] encoded, string expectedHeaderName, string expectedHeaderValue, bool expectDynamicTableEntry) + { + Assert.Equal(0, _dynamicTable.Count); + Assert.Equal(0, _dynamicTable.Size); + + _decoder.Decode(encoded, endHeaders: true, handler: this); + + Assert.Equal(expectedHeaderValue, _decodedHeaders[expectedHeaderName]); + + if (expectDynamicTableEntry) + { + Assert.Equal(1, _dynamicTable.Count); + Assert.Equal(expectedHeaderName, Encoding.ASCII.GetString(_dynamicTable[0].Name)); + Assert.Equal(expectedHeaderValue, Encoding.ASCII.GetString(_dynamicTable[0].Value)); + Assert.Equal(expectedHeaderName.Length + expectedHeaderValue.Length + 32, _dynamicTable.Size); + } + else + { + Assert.Equal(0, _dynamicTable.Count); + Assert.Equal(0, _dynamicTable.Size); + } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HPackEncoderTests.cs b/src/Servers/Kestrel/Core/test/HPackEncoderTests.cs new file mode 100644 index 0000000000..20313c6a90 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HPackEncoderTests.cs @@ -0,0 +1,129 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HPackEncoderTests + { + [Fact] + public void EncodesHeadersInSinglePayloadWhenSpaceAvailable() + { + var encoder = new HPackEncoder(); + + var statusCode = 200; + var headers = new [] + { + new KeyValuePair("date", "Mon, 24 Jul 2017 19:22:30 GMT"), + new KeyValuePair("content-type", "text/html; charset=utf-8"), + new KeyValuePair("server", "Kestrel") + }; + + var expectedPayload = new byte[] + { + 0x88, 0x00, 0x04, 0x64, 0x61, 0x74, 0x65, 0x1d, + 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x34, 0x20, + 0x4a, 0x75, 0x6c, 0x20, 0x32, 0x30, 0x31, 0x37, + 0x20, 0x31, 0x39, 0x3a, 0x32, 0x32, 0x3a, 0x33, + 0x30, 0x20, 0x47, 0x4d, 0x54, 0x00, 0x0c, 0x63, + 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x74, + 0x79, 0x70, 0x65, 0x18, 0x74, 0x65, 0x78, 0x74, + 0x2f, 0x68, 0x74, 0x6d, 0x6c, 0x3b, 0x20, 0x63, + 0x68, 0x61, 0x72, 0x73, 0x65, 0x74, 0x3d, 0x75, + 0x74, 0x66, 0x2d, 0x38, 0x00, 0x06, 0x73, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x07, 0x4b, 0x65, 0x73, + 0x74, 0x72, 0x65, 0x6c + }; + + var payload = new byte[1024]; + Assert.True(encoder.BeginEncode(statusCode, headers, payload, out var length)); + Assert.Equal(expectedPayload.Length, length); + + for (var i = 0; i < length; i++) + { + Assert.True(expectedPayload[i] == payload[i], $"{expectedPayload[i]} != {payload[i]} at {i} (len {length})"); + } + + Assert.Equal(expectedPayload, new ArraySegment(payload, 0, length)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void EncodesHeadersInMultiplePayloadsWhenSpaceNotAvailable(bool exactSize) + { + var encoder = new HPackEncoder(); + + var statusCode = 200; + var headers = new [] + { + new KeyValuePair("date", "Mon, 24 Jul 2017 19:22:30 GMT"), + new KeyValuePair("content-type", "text/html; charset=utf-8"), + new KeyValuePair("server", "Kestrel") + }; + + var expectedStatusCodePayload = new byte[] + { + 0x88 + }; + + var expectedDateHeaderPayload = new byte[] + { + 0x00, 0x04, 0x64, 0x61, 0x74, 0x65, 0x1d, 0x4d, + 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x34, 0x20, 0x4a, + 0x75, 0x6c, 0x20, 0x32, 0x30, 0x31, 0x37, 0x20, + 0x31, 0x39, 0x3a, 0x32, 0x32, 0x3a, 0x33, 0x30, + 0x20, 0x47, 0x4d, 0x54 + }; + + var expectedContentTypeHeaderPayload = new byte[] + { + 0x00, 0x0c, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, + 0x74, 0x2d, 0x74, 0x79, 0x70, 0x65, 0x18, 0x74, + 0x65, 0x78, 0x74, 0x2f, 0x68, 0x74, 0x6d, 0x6c, + 0x3b, 0x20, 0x63, 0x68, 0x61, 0x72, 0x73, 0x65, + 0x74, 0x3d, 0x75, 0x74, 0x66, 0x2d, 0x38 + }; + + var expectedServerHeaderPayload = new byte[] + { + 0x00, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x07, 0x4b, 0x65, 0x73, 0x74, 0x72, 0x65, 0x6c + }; + + Span payload = new byte[1024]; + var offset = 0; + + // When !exactSize, slices are one byte short of fitting the next header + var sliceLength = expectedStatusCodePayload.Length + (exactSize ? 0 : expectedDateHeaderPayload.Length - 1); + Assert.False(encoder.BeginEncode(statusCode, headers, payload.Slice(offset, sliceLength), out var length)); + Assert.Equal(expectedStatusCodePayload.Length, length); + Assert.Equal(expectedStatusCodePayload, payload.Slice(0, length).ToArray()); + + offset += length; + + sliceLength = expectedDateHeaderPayload.Length + (exactSize ? 0 : expectedContentTypeHeaderPayload.Length - 1); + Assert.False(encoder.Encode(payload.Slice(offset, sliceLength), out length)); + Assert.Equal(expectedDateHeaderPayload.Length, length); + Assert.Equal(expectedDateHeaderPayload, payload.Slice(offset, length).ToArray()); + + offset += length; + + sliceLength = expectedContentTypeHeaderPayload.Length + (exactSize ? 0 : expectedServerHeaderPayload.Length - 1); + Assert.False(encoder.Encode(payload.Slice(offset, sliceLength), out length)); + Assert.Equal(expectedContentTypeHeaderPayload.Length, length); + Assert.Equal(expectedContentTypeHeaderPayload, payload.Slice(offset, length).ToArray()); + + offset += length; + + sliceLength = expectedServerHeaderPayload.Length; + Assert.True(encoder.Encode(payload.Slice(offset, sliceLength), out length)); + Assert.Equal(expectedServerHeaderPayload.Length, length); + Assert.Equal(expectedServerHeaderPayload, payload.Slice(offset, length).ToArray()); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HeartbeatTests.cs b/src/Servers/Kestrel/Core/test/HeartbeatTests.cs new file mode 100644 index 0000000000..86fc9093f3 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HeartbeatTests.cs @@ -0,0 +1,100 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HeartbeatTests + { + [Fact] + public void HeartbeatIntervalIsOneSecond() + { + Assert.Equal(TimeSpan.FromSeconds(1), Heartbeat.Interval); + } + + [Fact] + public void BlockedHeartbeatDoesntCauseOverlapsAndIsLoggedAsError() + { + var systemClock = new MockSystemClock(); + var heartbeatHandler = new Mock(); + var debugger = new Mock(); + var kestrelTrace = new Mock(); + var handlerMre = new ManualResetEventSlim(); + var traceMre = new ManualResetEventSlim(); + var onHeartbeatTasks = new Task[2]; + + heartbeatHandler.Setup(h => h.OnHeartbeat(systemClock.UtcNow)).Callback(() => handlerMre.Wait()); + debugger.Setup(d => d.IsAttached).Returns(false); + kestrelTrace.Setup(t => t.HeartbeatSlow(Heartbeat.Interval, systemClock.UtcNow)).Callback(() => traceMre.Set()); + + using (var heartbeat = new Heartbeat(new[] { heartbeatHandler.Object }, systemClock, debugger.Object, kestrelTrace.Object)) + { + onHeartbeatTasks[0] = Task.Run(() => heartbeat.OnHeartbeat()); + onHeartbeatTasks[1] = Task.Run(() => heartbeat.OnHeartbeat()); + Assert.True(traceMre.Wait(TimeSpan.FromSeconds(10))); + } + + handlerMre.Set(); + Task.WaitAll(onHeartbeatTasks); + + heartbeatHandler.Verify(h => h.OnHeartbeat(systemClock.UtcNow), Times.Once()); + kestrelTrace.Verify(t => t.HeartbeatSlow(Heartbeat.Interval, systemClock.UtcNow), Times.Once()); + } + + [Fact] + public void BlockedHeartbeatIsNotLoggedAsErrorIfDebuggerAttached() + { + var systemClock = new MockSystemClock(); + var heartbeatHandler = new Mock(); + var debugger = new Mock(); + var kestrelTrace = new Mock(); + var handlerMre = new ManualResetEventSlim(); + var traceMre = new ManualResetEventSlim(); + var onHeartbeatTasks = new Task[2]; + + heartbeatHandler.Setup(h => h.OnHeartbeat(systemClock.UtcNow)).Callback(() => handlerMre.Wait()); + debugger.Setup(d => d.IsAttached).Returns(true); + kestrelTrace.Setup(t => t.HeartbeatSlow(Heartbeat.Interval, systemClock.UtcNow)).Callback(() => traceMre.Set()); + + using (var heartbeat = new Heartbeat(new[] { heartbeatHandler.Object }, systemClock, debugger.Object, kestrelTrace.Object, TimeSpan.FromSeconds(0.01))) + { + onHeartbeatTasks[0] = Task.Run(() => heartbeat.OnHeartbeat()); + onHeartbeatTasks[1] = Task.Run(() => heartbeat.OnHeartbeat()); + Assert.False(traceMre.Wait(TimeSpan.FromSeconds(2))); + } + + handlerMre.Set(); + Task.WaitAll(onHeartbeatTasks); + + heartbeatHandler.Verify(h => h.OnHeartbeat(systemClock.UtcNow), Times.Once()); + kestrelTrace.Verify(t => t.HeartbeatSlow(Heartbeat.Interval, systemClock.UtcNow), Times.Never()); + } + + [Fact] + public void ExceptionFromHeartbeatHandlerIsLoggedAsError() + { + var systemClock = new MockSystemClock(); + var heartbeatHandler = new Mock(); + var kestrelTrace = new TestKestrelTrace(); + var ex = new Exception(); + + heartbeatHandler.Setup(h => h.OnHeartbeat(systemClock.UtcNow)).Throws(ex); + + using (var heartbeat = new Heartbeat(new[] { heartbeatHandler.Object }, systemClock, DebuggerWrapper.Singleton, kestrelTrace)) + { + heartbeat.OnHeartbeat(); + } + + Assert.Equal(ex, kestrelTrace.Logger.Messages.Single(message => message.LogLevel == LogLevel.Error).Exception); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/Http1ConnectionTests.cs b/src/Servers/Kestrel/Core/test/Http1ConnectionTests.cs new file mode 100644 index 0000000000..42f689c125 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/Http1ConnectionTests.cs @@ -0,0 +1,1009 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Primitives; +using Microsoft.Net.Http.Headers; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class Http1ConnectionTests : IDisposable + { + private readonly IDuplexPipe _transport; + private readonly IDuplexPipe _application; + private readonly TestHttp1Connection _http1Connection; + private readonly ServiceContext _serviceContext; + private readonly Http1ConnectionContext _http1ConnectionContext; + private readonly MemoryPool _pipelineFactory; + private SequencePosition _consumed; + private SequencePosition _examined; + private Mock _timeoutControl; + + public Http1ConnectionTests() + { + _pipelineFactory = KestrelMemoryPool.Create(); + var options = new PipeOptions(_pipelineFactory, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + + _transport = pair.Transport; + _application = pair.Application; + + var connectionFeatures = new FeatureCollection(); + connectionFeatures.Set(Mock.Of()); + connectionFeatures.Set(Mock.Of()); + + _serviceContext = new TestServiceContext(); + _timeoutControl = new Mock(); + _http1ConnectionContext = new Http1ConnectionContext + { + ServiceContext = _serviceContext, + ConnectionContext = Mock.Of(), + ConnectionFeatures = connectionFeatures, + MemoryPool = _pipelineFactory, + TimeoutControl = _timeoutControl.Object, + Application = pair.Application, + Transport = pair.Transport + }; + + _http1Connection = new TestHttp1Connection(_http1ConnectionContext); + _http1Connection.Reset(); + } + + public void Dispose() + { + _transport.Input.Complete(); + _transport.Output.Complete(); + + _application.Input.Complete(); + _application.Output.Complete(); + + _pipelineFactory.Dispose(); + } + + [Fact] + public async Task TakeMessageHeadersThrowsWhenHeadersExceedTotalSizeLimit() + { + const string headerLine = "Header: value\r\n"; + _serviceContext.ServerOptions.Limits.MaxRequestHeadersTotalSize = headerLine.Length - 1; + _http1Connection.Reset(); + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"{headerLine}\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => _http1Connection.TakeMessageHeaders(readableBuffer, out _consumed, out _examined)); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.Equal(CoreStrings.BadRequest_HeadersExceedMaxTotalSize, exception.Message); + Assert.Equal(StatusCodes.Status431RequestHeaderFieldsTooLarge, exception.StatusCode); + } + + [Fact] + public async Task TakeMessageHeadersThrowsWhenHeadersExceedCountLimit() + { + const string headerLines = "Header-1: value1\r\nHeader-2: value2\r\n"; + _serviceContext.ServerOptions.Limits.MaxRequestHeaderCount = 1; + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"{headerLines}\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => _http1Connection.TakeMessageHeaders(readableBuffer, out _consumed, out _examined)); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.Equal(CoreStrings.BadRequest_TooManyHeaders, exception.Message); + Assert.Equal(StatusCodes.Status431RequestHeaderFieldsTooLarge, exception.StatusCode); + } + + [Fact] + public void ResetResetsScheme() + { + _http1Connection.Scheme = "https"; + + // Act + _http1Connection.Reset(); + + // Assert + Assert.Equal("http", ((IFeatureCollection)_http1Connection).Get().Scheme); + } + + [Fact] + public void ResetResetsTraceIdentifier() + { + _http1Connection.TraceIdentifier = "xyz"; + + _http1Connection.Reset(); + + var nextId = ((IFeatureCollection)_http1Connection).Get().TraceIdentifier; + Assert.NotEqual("xyz", nextId); + + _http1Connection.Reset(); + var secondId = ((IFeatureCollection)_http1Connection).Get().TraceIdentifier; + Assert.NotEqual(nextId, secondId); + } + + [Fact] + public void ResetResetsMinRequestBodyDataRate() + { + _http1Connection.MinRequestBodyDataRate = new MinDataRate(bytesPerSecond: 1, gracePeriod: TimeSpan.MaxValue); + + _http1Connection.Reset(); + + Assert.Same(_serviceContext.ServerOptions.Limits.MinRequestBodyDataRate, _http1Connection.MinRequestBodyDataRate); + } + + [Fact] + public void ResetResetsMinResponseDataRate() + { + _http1Connection.MinResponseDataRate = new MinDataRate(bytesPerSecond: 1, gracePeriod: TimeSpan.MaxValue); + + _http1Connection.Reset(); + + Assert.Same(_serviceContext.ServerOptions.Limits.MinResponseDataRate, _http1Connection.MinResponseDataRate); + } + + [Fact] + public void TraceIdentifierCountsRequestsPerHttp1Connection() + { + var connectionId = _http1ConnectionContext.ConnectionId; + var feature = ((IFeatureCollection)_http1Connection).Get(); + // Reset() is called once in the test ctor + var count = 1; + void Reset() + { + _http1Connection.Reset(); + count++; + } + + var nextId = feature.TraceIdentifier; + Assert.Equal($"{connectionId}:00000001", nextId); + + Reset(); + var secondId = feature.TraceIdentifier; + Assert.Equal($"{connectionId}:00000002", secondId); + + var big = 1_000_000; + while (big-- > 0) Reset(); + Assert.Equal($"{connectionId}:{count:X8}", feature.TraceIdentifier); + } + + [Fact] + public void TraceIdentifierGeneratesWhenNull() + { + _http1Connection.TraceIdentifier = null; + var id = _http1Connection.TraceIdentifier; + Assert.NotNull(id); + Assert.Equal(id, _http1Connection.TraceIdentifier); + + _http1Connection.Reset(); + Assert.NotEqual(id, _http1Connection.TraceIdentifier); + } + + [Fact] + public async Task ResetResetsHeaderLimits() + { + const string headerLine1 = "Header-1: value1\r\n"; + const string headerLine2 = "Header-2: value2\r\n"; + + var options = new KestrelServerOptions(); + options.Limits.MaxRequestHeadersTotalSize = headerLine1.Length; + options.Limits.MaxRequestHeaderCount = 1; + _serviceContext.ServerOptions = options; + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"{headerLine1}\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + var takeMessageHeaders = _http1Connection.TakeMessageHeaders(readableBuffer, out _consumed, out _examined); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.True(takeMessageHeaders); + Assert.Equal(1, _http1Connection.RequestHeaders.Count); + Assert.Equal("value1", _http1Connection.RequestHeaders["Header-1"]); + + _http1Connection.Reset(); + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"{headerLine2}\r\n")); + readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + takeMessageHeaders = _http1Connection.TakeMessageHeaders(readableBuffer, out _consumed, out _examined); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.True(takeMessageHeaders); + Assert.Equal(1, _http1Connection.RequestHeaders.Count); + Assert.Equal("value2", _http1Connection.RequestHeaders["Header-2"]); + } + + [Fact] + public async Task ThrowsWhenStatusCodeIsSetAfterResponseStarted() + { + // Act + await _http1Connection.WriteAsync(new ArraySegment(new byte[1])); + + // Assert + Assert.True(_http1Connection.HasResponseStarted); + Assert.Throws(() => ((IHttpResponseFeature)_http1Connection).StatusCode = StatusCodes.Status404NotFound); + } + + [Fact] + public async Task ThrowsWhenReasonPhraseIsSetAfterResponseStarted() + { + // Act + await _http1Connection.WriteAsync(new ArraySegment(new byte[1])); + + // Assert + Assert.True(_http1Connection.HasResponseStarted); + Assert.Throws(() => ((IHttpResponseFeature)_http1Connection).ReasonPhrase = "Reason phrase"); + } + + [Fact] + public async Task ThrowsWhenOnStartingIsSetAfterResponseStarted() + { + await _http1Connection.WriteAsync(new ArraySegment(new byte[1])); + + // Act/Assert + Assert.True(_http1Connection.HasResponseStarted); + Assert.Throws(() => ((IHttpResponseFeature)_http1Connection).OnStarting(_ => Task.CompletedTask, null)); + } + + [Theory] + [MemberData(nameof(MinDataRateData))] + public void ConfiguringIHttpMinRequestBodyDataRateFeatureSetsMinRequestBodyDataRate(MinDataRate minDataRate) + { + ((IFeatureCollection)_http1Connection).Get().MinDataRate = minDataRate; + + Assert.Same(minDataRate, _http1Connection.MinRequestBodyDataRate); + } + + [Theory] + [MemberData(nameof(MinDataRateData))] + public void ConfiguringIHttpMinResponseDataRateFeatureSetsMinResponseDataRate(MinDataRate minDataRate) + { + ((IFeatureCollection)_http1Connection).Get().MinDataRate = minDataRate; + + Assert.Same(minDataRate, _http1Connection.MinResponseDataRate); + } + + [Fact] + public void ResetResetsRequestHeaders() + { + // Arrange + var originalRequestHeaders = _http1Connection.RequestHeaders; + _http1Connection.RequestHeaders = new HttpRequestHeaders(); + + // Act + _http1Connection.Reset(); + + // Assert + Assert.Same(originalRequestHeaders, _http1Connection.RequestHeaders); + } + + [Fact] + public void ResetResetsResponseHeaders() + { + // Arrange + var originalResponseHeaders = _http1Connection.ResponseHeaders; + _http1Connection.ResponseHeaders = new HttpResponseHeaders(); + + // Act + _http1Connection.Reset(); + + // Assert + Assert.Same(originalResponseHeaders, _http1Connection.ResponseHeaders); + } + + [Fact] + public void InitializeStreamsResetsStreams() + { + // Arrange + var messageBody = Http1MessageBody.For(Kestrel.Core.Internal.Http.HttpVersion.Http11, (HttpRequestHeaders)_http1Connection.RequestHeaders, _http1Connection); + _http1Connection.InitializeStreams(messageBody); + + var originalRequestBody = _http1Connection.RequestBody; + var originalResponseBody = _http1Connection.ResponseBody; + _http1Connection.RequestBody = new MemoryStream(); + _http1Connection.ResponseBody = new MemoryStream(); + + // Act + _http1Connection.InitializeStreams(messageBody); + + // Assert + Assert.Same(originalRequestBody, _http1Connection.RequestBody); + Assert.Same(originalResponseBody, _http1Connection.ResponseBody); + } + + [Theory] + [MemberData(nameof(RequestLineValidData))] + public async Task TakeStartLineSetsHttpProtocolProperties( + string requestLine, + string expectedMethod, + string expectedRawTarget, + // This warns that theory methods should use all of their parameters, + // but this method is using a shared data collection with HttpParserTests.ParsesRequestLine and others. +#pragma warning disable xUnit1026 + string expectedRawPath, +#pragma warning restore xUnit1026 + string expectedDecodedPath, + string expectedQueryString, + string expectedHttpVersion) + { + var requestLineBytes = Encoding.ASCII.GetBytes(requestLine); + await _application.Output.WriteAsync(requestLineBytes); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + var returnValue = _http1Connection.TakeStartLine(readableBuffer, out _consumed, out _examined); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.True(returnValue); + Assert.Equal(expectedMethod, ((IHttpRequestFeature)_http1Connection).Method); + Assert.Equal(expectedRawTarget, _http1Connection.RawTarget); + Assert.Equal(expectedDecodedPath, _http1Connection.Path); + Assert.Equal(expectedQueryString, _http1Connection.QueryString); + Assert.Equal(expectedHttpVersion, _http1Connection.HttpVersion); + } + + [Theory] + [MemberData(nameof(RequestLineDotSegmentData))] + public async Task TakeStartLineRemovesDotSegmentsFromTarget( + string requestLine, + string expectedRawTarget, + string expectedDecodedPath, + string expectedQueryString) + { + var requestLineBytes = Encoding.ASCII.GetBytes(requestLine); + await _application.Output.WriteAsync(requestLineBytes); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + var returnValue = _http1Connection.TakeStartLine(readableBuffer, out _consumed, out _examined); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.True(returnValue); + Assert.Equal(expectedRawTarget, _http1Connection.RawTarget); + Assert.Equal(expectedDecodedPath, _http1Connection.Path); + Assert.Equal(expectedQueryString, _http1Connection.QueryString); + } + + [Fact] + public async Task ParseRequestStartsRequestHeadersTimeoutOnFirstByteAvailable() + { + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes("G")); + + _http1Connection.ParseRequest((await _transport.Input.ReadAsync()).Buffer, out _consumed, out _examined); + _transport.Input.AdvanceTo(_consumed, _examined); + + var expectedRequestHeadersTimeout = _serviceContext.ServerOptions.Limits.RequestHeadersTimeout.Ticks; + _timeoutControl.Verify(cc => cc.ResetTimeout(expectedRequestHeadersTimeout, TimeoutAction.SendTimeoutResponse)); + } + + [Fact] + public async Task TakeStartLineThrowsWhenTooLong() + { + _serviceContext.ServerOptions.Limits.MaxRequestLineSize = "GET / HTTP/1.1\r\n".Length; + + var requestLineBytes = Encoding.ASCII.GetBytes("GET /a HTTP/1.1\r\n"); + await _application.Output.WriteAsync(requestLineBytes); + + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + var exception = Assert.Throws(() => _http1Connection.TakeStartLine(readableBuffer, out _consumed, out _examined)); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.Equal(CoreStrings.BadRequest_RequestLineTooLong, exception.Message); + Assert.Equal(StatusCodes.Status414UriTooLong, exception.StatusCode); + } + + [Theory] + [MemberData(nameof(TargetWithEncodedNullCharData))] + public async Task TakeStartLineThrowsOnEncodedNullCharInTarget(string target) + { + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"GET {target} HTTP/1.1\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => + _http1Connection.TakeStartLine(readableBuffer, out _consumed, out _examined)); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(target), exception.Message); + } + + [Theory] + [MemberData(nameof(TargetWithNullCharData))] + public async Task TakeStartLineThrowsOnNullCharInTarget(string target) + { + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"GET {target} HTTP/1.1\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => + _http1Connection.TakeStartLine(readableBuffer, out _consumed, out _examined)); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(target.EscapeNonPrintable()), exception.Message); + } + + [Theory] + [MemberData(nameof(MethodWithNullCharData))] + public async Task TakeStartLineThrowsOnNullCharInMethod(string method) + { + var requestLine = $"{method} / HTTP/1.1\r\n"; + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(requestLine)); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => + _http1Connection.TakeStartLine(readableBuffer, out _consumed, out _examined)); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestLine_Detail(requestLine.EscapeNonPrintable()), exception.Message); + } + + [Theory] + [MemberData(nameof(QueryStringWithNullCharData))] + public async Task TakeStartLineThrowsOnNullCharInQueryString(string queryString) + { + var target = $"/{queryString}"; + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"GET {target} HTTP/1.1\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => + _http1Connection.TakeStartLine(readableBuffer, out _consumed, out _examined)); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(target.EscapeNonPrintable()), exception.Message); + } + + [Theory] + [MemberData(nameof(TargetInvalidData))] + public async Task TakeStartLineThrowsWhenRequestTargetIsInvalid(string method, string target) + { + var requestLine = $"{method} {target} HTTP/1.1\r\n"; + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(requestLine)); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => + _http1Connection.TakeStartLine(readableBuffer, out _consumed, out _examined)); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(target.EscapeNonPrintable()), exception.Message); + } + + [Theory] + [MemberData(nameof(MethodNotAllowedTargetData))] + public async Task TakeStartLineThrowsWhenMethodNotAllowed(string requestLine, HttpMethod allowedMethod) + { + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(requestLine)); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => + _http1Connection.TakeStartLine(readableBuffer, out _consumed, out _examined)); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.Equal(405, exception.StatusCode); + Assert.Equal(CoreStrings.BadRequest_MethodNotAllowed, exception.Message); + Assert.Equal(HttpUtilities.MethodToString(allowedMethod), exception.AllowedHeader); + } + + [Fact] + public void ProcessRequestsAsyncEnablesKeepAliveTimeout() + { + var requestProcessingTask = _http1Connection.ProcessRequestsAsync(null); + + var expectedKeepAliveTimeout = _serviceContext.ServerOptions.Limits.KeepAliveTimeout.Ticks; + _timeoutControl.Verify(cc => cc.SetTimeout(expectedKeepAliveTimeout, TimeoutAction.StopProcessingNextRequest)); + + _http1Connection.StopProcessingNextRequest(); + _application.Output.Complete(); + + requestProcessingTask.Wait(); + } + + [Fact] + public async Task WriteThrowsForNonBodyResponse() + { + // Arrange + ((IHttpResponseFeature)_http1Connection).StatusCode = StatusCodes.Status304NotModified; + + // Act/Assert + await Assert.ThrowsAsync(() => _http1Connection.WriteAsync(new ArraySegment(new byte[1]))); + } + + [Fact] + public async Task WriteAsyncThrowsForNonBodyResponse() + { + // Arrange + _http1Connection.HttpVersion = "HTTP/1.1"; + ((IHttpResponseFeature)_http1Connection).StatusCode = StatusCodes.Status304NotModified; + + // Act/Assert + await Assert.ThrowsAsync(() => _http1Connection.WriteAsync(new ArraySegment(new byte[1]), default(CancellationToken))); + } + + [Fact] + public async Task WriteDoesNotThrowForHeadResponse() + { + // Arrange + _http1Connection.HttpVersion = "HTTP/1.1"; + _http1Connection.Method = HttpMethod.Head; + + // Act/Assert + await _http1Connection.WriteAsync(new ArraySegment(new byte[1])); + } + + [Fact] + public async Task WriteAsyncDoesNotThrowForHeadResponse() + { + // Arrange + _http1Connection.HttpVersion = "HTTP/1.1"; + _http1Connection.Method = HttpMethod.Head; + + // Act/Assert + await _http1Connection.WriteAsync(new ArraySegment(new byte[1]), default(CancellationToken)); + } + + [Fact] + public async Task ManuallySettingTransferEncodingThrowsForHeadResponse() + { + // Arrange + _http1Connection.HttpVersion = "HTTP/1.1"; + _http1Connection.Method = HttpMethod.Head; + + // Act + _http1Connection.ResponseHeaders.Add("Transfer-Encoding", "chunked"); + + // Assert + await Assert.ThrowsAsync(() => _http1Connection.FlushAsync()); + } + + [Fact] + public async Task ManuallySettingTransferEncodingThrowsForNoBodyResponse() + { + // Arrange + _http1Connection.HttpVersion = "HTTP/1.1"; + ((IHttpResponseFeature)_http1Connection).StatusCode = StatusCodes.Status304NotModified; + + // Act + _http1Connection.ResponseHeaders.Add("Transfer-Encoding", "chunked"); + + // Assert + await Assert.ThrowsAsync(() => _http1Connection.FlushAsync()); + } + + [Fact] + public async Task RequestProcessingTaskIsUnwrapped() + { + var requestProcessingTask = _http1Connection.ProcessRequestsAsync(null); + + var data = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"); + await _application.Output.WriteAsync(data); + + _http1Connection.StopProcessingNextRequest(); + Assert.IsNotType>(requestProcessingTask); + + await requestProcessingTask.DefaultTimeout(); + _application.Output.Complete(); + } + + [Fact] + public async Task RequestAbortedTokenIsResetBeforeLastWriteWithContentLength() + { + _http1Connection.ResponseHeaders["Content-Length"] = "12"; + + // Need to compare WaitHandle ref since CancellationToken is struct + var original = _http1Connection.RequestAborted.WaitHandle; + + foreach (var ch in "hello, worl") + { + await _http1Connection.WriteAsync(new ArraySegment(new[] { (byte)ch })); + Assert.Same(original, _http1Connection.RequestAborted.WaitHandle); + } + + await _http1Connection.WriteAsync(new ArraySegment(new[] { (byte)'d' })); + Assert.NotSame(original, _http1Connection.RequestAborted.WaitHandle); + } + + [Fact] + public async Task RequestAbortedTokenIsResetBeforeLastWriteAsyncWithContentLength() + { + _http1Connection.ResponseHeaders["Content-Length"] = "12"; + + // Need to compare WaitHandle ref since CancellationToken is struct + var original = _http1Connection.RequestAborted.WaitHandle; + + foreach (var ch in "hello, worl") + { + await _http1Connection.WriteAsync(new ArraySegment(new[] { (byte)ch }), default(CancellationToken)); + Assert.Same(original, _http1Connection.RequestAborted.WaitHandle); + } + + await _http1Connection.WriteAsync(new ArraySegment(new[] { (byte)'d' }), default(CancellationToken)); + Assert.NotSame(original, _http1Connection.RequestAborted.WaitHandle); + } + + [Fact] + public async Task RequestAbortedTokenIsResetBeforeLastWriteAsyncAwaitedWithContentLength() + { + _http1Connection.ResponseHeaders["Content-Length"] = "12"; + + // Need to compare WaitHandle ref since CancellationToken is struct + var original = _http1Connection.RequestAborted.WaitHandle; + + // Only first write can be WriteAsyncAwaited + var startingTask = _http1Connection.InitializeResponseAwaited(Task.CompletedTask, 1); + await _http1Connection.WriteAsyncAwaited(startingTask, new ArraySegment(new[] { (byte)'h' }), default(CancellationToken)); + Assert.Same(original, _http1Connection.RequestAborted.WaitHandle); + + foreach (var ch in "ello, worl") + { + await _http1Connection.WriteAsync(new ArraySegment(new[] { (byte)ch }), default(CancellationToken)); + Assert.Same(original, _http1Connection.RequestAborted.WaitHandle); + } + + await _http1Connection.WriteAsync(new ArraySegment(new[] { (byte)'d' }), default(CancellationToken)); + Assert.NotSame(original, _http1Connection.RequestAborted.WaitHandle); + } + + [Fact] + public async Task RequestAbortedTokenIsResetBeforeLastWriteWithChunkedEncoding() + { + // Need to compare WaitHandle ref since CancellationToken is struct + var original = _http1Connection.RequestAborted.WaitHandle; + + _http1Connection.HttpVersion = "HTTP/1.1"; + await _http1Connection.WriteAsync(new ArraySegment(Encoding.ASCII.GetBytes("hello, world")), default(CancellationToken)); + Assert.Same(original, _http1Connection.RequestAborted.WaitHandle); + + await _http1Connection.ProduceEndAsync(); + Assert.NotSame(original, _http1Connection.RequestAborted.WaitHandle); + } + + [Fact] + public async Task ExceptionDetailNotIncludedWhenLogLevelInformationNotEnabled() + { + var previousLog = _serviceContext.Log; + + try + { + var mockTrace = new Mock(); + mockTrace + .Setup(trace => trace.IsEnabled(LogLevel.Information)) + .Returns(false); + + _serviceContext.Log = mockTrace.Object; + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"GET /%00 HTTP/1.1\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; + + var exception = Assert.Throws(() => + _http1Connection.TakeStartLine(readableBuffer, out _consumed, out _examined)); + _transport.Input.AdvanceTo(_consumed, _examined); + + Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(string.Empty), exception.Message); + Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); + } + finally + { + _serviceContext.Log = previousLog; + } + } + + [Theory] + [InlineData(1, 1)] + [InlineData(5, 5)] + [InlineData(100, 100)] + [InlineData(600, 100)] + [InlineData(700, 1)] + [InlineData(1, 700)] + public async Task AcceptsHeadersAcrossSends(int header0Count, int header1Count) + { + _serviceContext.ServerOptions.Limits.MaxRequestHeaderCount = header0Count + header1Count; + + var headers0 = MakeHeaders(header0Count); + var headers1 = MakeHeaders(header1Count, header0Count); + + var requestProcessingTask = _http1Connection.ProcessRequestsAsync(null); + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes("GET / HTTP/1.0\r\n")); + await WaitForCondition(TestConstants.DefaultTimeout, () => _http1Connection.RequestHeaders != null); + Assert.Equal(0, _http1Connection.RequestHeaders.Count); + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(headers0)); + await WaitForCondition(TestConstants.DefaultTimeout, () => _http1Connection.RequestHeaders.Count >= header0Count); + Assert.Equal(header0Count, _http1Connection.RequestHeaders.Count); + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(headers1)); + await WaitForCondition(TestConstants.DefaultTimeout, () => _http1Connection.RequestHeaders.Count >= header0Count + header1Count); + Assert.Equal(header0Count + header1Count, _http1Connection.RequestHeaders.Count); + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes("\r\n")); + await requestProcessingTask.DefaultTimeout(); + } + + [Theory] + [InlineData(1, 1)] + [InlineData(5, 5)] + [InlineData(100, 100)] + [InlineData(600, 100)] + [InlineData(700, 1)] + [InlineData(1, 700)] + public async Task KeepsSameHeaderCollectionAcrossSends(int header0Count, int header1Count) + { + _serviceContext.ServerOptions.Limits.MaxRequestHeaderCount = header0Count + header1Count; + + var headers0 = MakeHeaders(header0Count); + var headers1 = MakeHeaders(header1Count, header0Count); + + var requestProcessingTask = _http1Connection.ProcessRequestsAsync(null); + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes("GET / HTTP/1.0\r\n")); + await WaitForCondition(TestConstants.DefaultTimeout, () => _http1Connection.RequestHeaders != null); + Assert.Equal(0, _http1Connection.RequestHeaders.Count); + + var newRequestHeaders = new RequestHeadersWrapper(_http1Connection.RequestHeaders); + _http1Connection.RequestHeaders = newRequestHeaders; + Assert.Same(newRequestHeaders, _http1Connection.RequestHeaders); + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(headers0)); + await WaitForCondition(TestConstants.DefaultTimeout, () => _http1Connection.RequestHeaders.Count >= header0Count); + Assert.Same(newRequestHeaders, _http1Connection.RequestHeaders); + Assert.Equal(header0Count, _http1Connection.RequestHeaders.Count); + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(headers1)); + await WaitForCondition(TestConstants.DefaultTimeout, () => _http1Connection.RequestHeaders.Count >= header0Count + header1Count); + Assert.Same(newRequestHeaders, _http1Connection.RequestHeaders); + Assert.Equal(header0Count + header1Count, _http1Connection.RequestHeaders.Count); + + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes("\r\n")); + await requestProcessingTask.TimeoutAfter(TimeSpan.FromSeconds(10)); + } + + [Fact] + public void ThrowsWhenMaxRequestBodySizeIsSetAfterReadingFromRequestBody() + { + // Act + // This would normally be set by the MessageBody during the first read. + _http1Connection.HasStartedConsumingRequestBody = true; + + // Assert + Assert.True(((IHttpMaxRequestBodySizeFeature)_http1Connection).IsReadOnly); + var ex = Assert.Throws(() => ((IHttpMaxRequestBodySizeFeature)_http1Connection).MaxRequestBodySize = 1); + Assert.Equal(CoreStrings.MaxRequestBodySizeCannotBeModifiedAfterRead, ex.Message); + } + + [Fact] + public void ThrowsWhenMaxRequestBodySizeIsSetToANegativeValue() + { + // Assert + var ex = Assert.Throws(() => ((IHttpMaxRequestBodySizeFeature)_http1Connection).MaxRequestBodySize = -1); + Assert.StartsWith(CoreStrings.NonNegativeNumberOrNullRequired, ex.Message); + } + + [Fact] + public async Task ConsumesRequestWhenApplicationDoesNotConsumeIt() + { + var httpApplication = new DummyApplication(async context => + { + var buffer = new byte[10]; + await context.Response.Body.WriteAsync(buffer, 0, 10); + }); + var mockMessageBody = new Mock(null); + _http1Connection.NextMessageBody = mockMessageBody.Object; + + var requestProcessingTask = _http1Connection.ProcessRequestsAsync(httpApplication); + + var data = Encoding.ASCII.GetBytes("POST / HTTP/1.1\r\nHost:\r\nConnection: close\r\ncontent-length: 1\r\n\r\n"); + await _application.Output.WriteAsync(data); + await requestProcessingTask.DefaultTimeout(); + + mockMessageBody.Verify(body => body.ConsumeAsync(), Times.Once); + } + + [Fact] + public void Http10HostHeaderNotRequired() + { + _http1Connection.HttpVersion = "HTTP/1.0"; + _http1Connection.EnsureHostHeaderExists(); + } + + [Fact] + public void Http10HostHeaderAllowed() + { + _http1Connection.HttpVersion = "HTTP/1.0"; + _http1Connection.RequestHeaders[HeaderNames.Host] = "localhost:5000"; + _http1Connection.EnsureHostHeaderExists(); + } + + [Fact] + public void Http11EmptyHostHeaderAccepted() + { + _http1Connection.HttpVersion = "HTTP/1.1"; + _http1Connection.RequestHeaders[HeaderNames.Host] = ""; + _http1Connection.EnsureHostHeaderExists(); + } + + [Fact] + public void Http11ValidHostHeadersAccepted() + { + _http1Connection.HttpVersion = "HTTP/1.1"; + _http1Connection.RequestHeaders[HeaderNames.Host] = "localhost:5000"; + _http1Connection.EnsureHostHeaderExists(); + } + + [Fact] + public void BadRequestFor10BadHostHeaderFormat() + { + _http1Connection.HttpVersion = "HTTP/1.0"; + _http1Connection.RequestHeaders[HeaderNames.Host] = "a=b"; + var ex = Assert.Throws(() => _http1Connection.EnsureHostHeaderExists()); + Assert.Equal(CoreStrings.FormatBadRequest_InvalidHostHeader_Detail("a=b"), ex.Message); + } + + [Fact] + public void BadRequestFor11BadHostHeaderFormat() + { + _http1Connection.HttpVersion = "HTTP/1.1"; + _http1Connection.RequestHeaders[HeaderNames.Host] = "a=b"; + var ex = Assert.Throws(() => _http1Connection.EnsureHostHeaderExists()); + Assert.Equal(CoreStrings.FormatBadRequest_InvalidHostHeader_Detail("a=b"), ex.Message); + } + + private static async Task WaitForCondition(TimeSpan timeout, Func condition) + { + const int MaxWaitLoop = 150; + + var delay = (int)Math.Ceiling(timeout.TotalMilliseconds / MaxWaitLoop); + + var waitLoop = 0; + while (waitLoop < MaxWaitLoop && !condition()) + { + // Wait for parsing condition to trigger + await Task.Delay(delay); + waitLoop++; + } + } + + private static string MakeHeaders(int count, int startAt = 0) + { + return string.Join("", Enumerable + .Range(0, count) + .Select(i => $"Header-{startAt + i}: value{startAt + i}\r\n")); + } + + public static IEnumerable RequestLineValidData => HttpParsingData.RequestLineValidData; + + public static IEnumerable RequestLineDotSegmentData => HttpParsingData.RequestLineDotSegmentData; + + public static TheoryData TargetWithEncodedNullCharData + { + get + { + var data = new TheoryData(); + + foreach (var target in HttpParsingData.TargetWithEncodedNullCharData) + { + data.Add(target); + } + + return data; + } + } + + public static TheoryData TargetInvalidData + => HttpParsingData.TargetInvalidData; + + public static TheoryData MethodNotAllowedTargetData + => HttpParsingData.MethodNotAllowedRequestLine; + + public static TheoryData TargetWithNullCharData + { + get + { + var data = new TheoryData(); + + foreach (var target in HttpParsingData.TargetWithNullCharData) + { + data.Add(target); + } + + return data; + } + } + + public static TheoryData MethodWithNullCharData + { + get + { + var data = new TheoryData(); + + foreach (var target in HttpParsingData.MethodWithNullCharData) + { + data.Add(target); + } + + return data; + } + } + + public static TheoryData QueryStringWithNullCharData + { + get + { + var data = new TheoryData(); + + foreach (var target in HttpParsingData.QueryStringWithNullCharData) + { + data.Add(target); + } + + return data; + } + } + + public static TheoryData RequestBodyTimeoutDataValid => new TheoryData + { + TimeSpan.FromTicks(1), + TimeSpan.MaxValue, + Timeout.InfiniteTimeSpan, + TimeSpan.FromMilliseconds(-1) // Same as Timeout.InfiniteTimeSpan + }; + + public static TheoryData RequestBodyTimeoutDataInvalid => new TheoryData + { + TimeSpan.MinValue, + TimeSpan.FromTicks(-1), + TimeSpan.Zero + }; + + public static TheoryData MinDataRateData => new TheoryData + { + null, + new MinDataRate(bytesPerSecond: 1, gracePeriod: TimeSpan.MaxValue) + }; + + private class RequestHeadersWrapper : IHeaderDictionary + { + IHeaderDictionary _innerHeaders; + + public RequestHeadersWrapper(IHeaderDictionary headers) + { + _innerHeaders = headers; + } + + public StringValues this[string key] { get => _innerHeaders[key]; set => _innerHeaders[key] = value; } + public long? ContentLength { get => _innerHeaders.ContentLength; set => _innerHeaders.ContentLength = value; } + public ICollection Keys => _innerHeaders.Keys; + public ICollection Values => _innerHeaders.Values; + public int Count => _innerHeaders.Count; + public bool IsReadOnly => _innerHeaders.IsReadOnly; + public void Add(string key, StringValues value) => _innerHeaders.Add(key, value); + public void Add(KeyValuePair item) => _innerHeaders.Add(item); + public void Clear() => _innerHeaders.Clear(); + public bool Contains(KeyValuePair item) => _innerHeaders.Contains(item); + public bool ContainsKey(string key) => _innerHeaders.ContainsKey(key); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => _innerHeaders.CopyTo(array, arrayIndex); + public IEnumerator> GetEnumerator() => _innerHeaders.GetEnumerator(); + public bool Remove(string key) => _innerHeaders.Remove(key); + public bool Remove(KeyValuePair item) => _innerHeaders.Remove(item); + public bool TryGetValue(string key, out StringValues value) => _innerHeaders.TryGetValue(key, out value); + IEnumerator IEnumerable.GetEnumerator() => _innerHeaders.GetEnumerator(); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/Http2ConnectionTests.cs b/src/Servers/Kestrel/Core/test/Http2ConnectionTests.cs new file mode 100644 index 0000000000..f8038dffff --- /dev/null +++ b/src/Servers/Kestrel/Core/test/Http2ConnectionTests.cs @@ -0,0 +1,3059 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.Net.Http.Headers; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class Http2ConnectionTests : IDisposable, IHttpHeadersHandler + { + private static readonly string _largeHeaderValue = new string('a', HPackDecoder.MaxStringOctets); + + private static readonly IEnumerable> _postRequestHeaders = new[] + { + new KeyValuePair(":method", "POST"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":authority", "localhost:80"), + }; + + private static readonly IEnumerable> _expectContinueRequestHeaders = new[] + { + new KeyValuePair(":method", "POST"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":authority", "127.0.0.1"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair("expect", "100-continue"), + }; + + private static readonly IEnumerable> _browserRequestHeaders = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":authority", "localhost:80"), + new KeyValuePair("user-agent", "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:54.0) Gecko/20100101 Firefox/54.0"), + new KeyValuePair("accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"), + new KeyValuePair("accept-language", "en-US,en;q=0.5"), + new KeyValuePair("accept-encoding", "gzip, deflate, br"), + new KeyValuePair("upgrade-insecure-requests", "1"), + }; + + private static readonly IEnumerable> _requestTrailers = new[] + { + new KeyValuePair("trailer-one", "1"), + new KeyValuePair("trailer-two", "2"), + }; + + private static readonly IEnumerable> _oneContinuationRequestHeaders = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":authority", "localhost:80"), + new KeyValuePair("a", _largeHeaderValue), + new KeyValuePair("b", _largeHeaderValue), + new KeyValuePair("c", _largeHeaderValue), + new KeyValuePair("d", _largeHeaderValue) + }; + + private static readonly IEnumerable> _twoContinuationsRequestHeaders = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":authority", "localhost:80"), + new KeyValuePair("a", _largeHeaderValue), + new KeyValuePair("b", _largeHeaderValue), + new KeyValuePair("c", _largeHeaderValue), + new KeyValuePair("d", _largeHeaderValue), + new KeyValuePair("e", _largeHeaderValue), + new KeyValuePair("f", _largeHeaderValue), + new KeyValuePair("g", _largeHeaderValue), + new KeyValuePair("h", _largeHeaderValue) + }; + + private static readonly byte[] _helloBytes = Encoding.ASCII.GetBytes("hello"); + private static readonly byte[] _worldBytes = Encoding.ASCII.GetBytes("world"); + private static readonly byte[] _helloWorldBytes = Encoding.ASCII.GetBytes("hello, world"); + private static readonly byte[] _noData = new byte[0]; + private static readonly byte[] _maxData = Encoding.ASCII.GetBytes(new string('a', Http2Frame.MinAllowedMaxFrameSize)); + + private readonly MemoryPool _memoryPool = KestrelMemoryPool.Create(); + private readonly DuplexPipe.DuplexPipePair _pair; + private readonly TestApplicationErrorLogger _logger; + private readonly Http2ConnectionContext _connectionContext; + private readonly Http2Connection _connection; + private readonly Http2PeerSettings _clientSettings = new Http2PeerSettings(); + private readonly HPackEncoder _hpackEncoder = new HPackEncoder(); + private readonly HPackDecoder _hpackDecoder; + + private readonly ConcurrentDictionary> _runningStreams = new ConcurrentDictionary>(); + private readonly Dictionary _receivedHeaders = new Dictionary(StringComparer.OrdinalIgnoreCase); + private readonly Dictionary _decodedHeaders = new Dictionary(StringComparer.OrdinalIgnoreCase); + private readonly HashSet _abortedStreamIds = new HashSet(); + private readonly object _abortedStreamIdsLock = new object(); + + private readonly RequestDelegate _noopApplication; + private readonly RequestDelegate _echoHost; + private readonly RequestDelegate _readHeadersApplication; + private readonly RequestDelegate _readTrailersApplication; + private readonly RequestDelegate _bufferingApplication; + private readonly RequestDelegate _echoApplication; + private readonly RequestDelegate _echoWaitForAbortApplication; + private readonly RequestDelegate _largeHeadersApplication; + private readonly RequestDelegate _waitForAbortApplication; + private readonly RequestDelegate _waitForAbortFlushingApplication; + + private Task _connectionTask; + + public Http2ConnectionTests() + { + var inlineSchedulingPipeOptions = new PipeOptions( + pool: _memoryPool, + readerScheduler: PipeScheduler.Inline, + writerScheduler: PipeScheduler.Inline, + useSynchronizationContext: false + ); + + _pair = DuplexPipe.CreateConnectionPair(inlineSchedulingPipeOptions, inlineSchedulingPipeOptions); + + _noopApplication = context => Task.CompletedTask; + + _echoHost = context => + { + context.Response.Headers[HeaderNames.Host] = context.Request.Headers[HeaderNames.Host]; + + return Task.CompletedTask; + }; + + _readHeadersApplication = context => + { + foreach (var header in context.Request.Headers) + { + _receivedHeaders[header.Key] = header.Value.ToString(); + } + + return Task.CompletedTask; + }; + + _readTrailersApplication = async context => + { + using (var ms = new MemoryStream()) + { + // Consuming the entire request body guarantees trailers will be available + await context.Request.Body.CopyToAsync(ms); + } + + foreach (var header in context.Request.Headers) + { + _receivedHeaders[header.Key] = header.Value.ToString(); + } + }; + + _bufferingApplication = async context => + { + var data = new List(); + var buffer = new byte[1024]; + var received = 0; + + while ((received = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length)) > 0) + { + data.AddRange(new ArraySegment(buffer, 0, received)); + } + + await context.Response.Body.WriteAsync(data.ToArray(), 0, data.Count); + }; + + _echoApplication = async context => + { + var buffer = new byte[Http2Frame.MinAllowedMaxFrameSize]; + var received = 0; + + while ((received = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length)) > 0) + { + await context.Response.Body.WriteAsync(buffer, 0, received); + } + }; + + _echoWaitForAbortApplication = async context => + { + var buffer = new byte[Http2Frame.MinAllowedMaxFrameSize]; + var received = 0; + + while ((received = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length)) > 0) + { + await context.Response.Body.WriteAsync(buffer, 0, received); + } + + var sem = new SemaphoreSlim(0); + + context.RequestAborted.Register(() => + { + sem.Release(); + }); + + await sem.WaitAsync().DefaultTimeout(); + }; + + _largeHeadersApplication = context => + { + foreach (var name in new[] { "a", "b", "c", "d", "e", "f", "g", "h" }) + { + context.Response.Headers[name] = _largeHeaderValue; + } + + return Task.CompletedTask; + }; + + _waitForAbortApplication = async context => + { + var streamIdFeature = context.Features.Get(); + var sem = new SemaphoreSlim(0); + + context.RequestAborted.Register(() => + { + lock (_abortedStreamIdsLock) + { + _abortedStreamIds.Add(streamIdFeature.StreamId); + } + + sem.Release(); + }); + + await sem.WaitAsync().DefaultTimeout(); + + _runningStreams[streamIdFeature.StreamId].TrySetResult(null); + }; + + _waitForAbortFlushingApplication = async context => + { + var streamIdFeature = context.Features.Get(); + var sem = new SemaphoreSlim(0); + + context.RequestAborted.Register(() => + { + lock (_abortedStreamIdsLock) + { + _abortedStreamIds.Add(streamIdFeature.StreamId); + } + + sem.Release(); + }); + + await sem.WaitAsync().DefaultTimeout(); + + await context.Response.Body.FlushAsync(); + + _runningStreams[streamIdFeature.StreamId].TrySetResult(null); + }; + + _hpackDecoder = new HPackDecoder((int)_clientSettings.HeaderTableSize); + + _logger = new TestApplicationErrorLogger(); + + _connectionContext = new Http2ConnectionContext + { + ServiceContext = new TestServiceContext() + { + Log = new TestKestrelTrace(_logger) + }, + MemoryPool = _memoryPool, + Application = _pair.Application, + Transport = _pair.Transport + }; + _connection = new Http2Connection(_connectionContext); + } + + public void Dispose() + { + _memoryPool.Dispose(); + } + + void IHttpHeadersHandler.OnHeader(Span name, Span value) + { + _decodedHeaders[name.GetAsciiStringNonNullCharacters()] = value.GetAsciiStringNonNullCharacters(); + } + + [Fact] + public async Task DATA_Received_ReadByStream() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await SendDataAsync(1, _helloWorldBytes, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + var dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 12, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + Assert.Equal(dataFrame.DataPayload, _helloWorldBytes); + } + + [Fact] + public async Task DATA_Received_MaxSize_ReadByStream() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await SendDataAsync(1, _maxData, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + var dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: _maxData.Length, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + Assert.Equal(dataFrame.DataPayload, _maxData); + } + + [Fact] + public async Task DATA_Received_Multiple_ReadByStream() + { + await InitializeConnectionAsync(_bufferingApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + + for (var i = 0; i < _helloWorldBytes.Length; i++) + { + await SendDataAsync(1, new ArraySegment(_helloWorldBytes, i, 1), endStream: false); + } + + await SendDataAsync(1, _noData, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + var dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 12, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + Assert.Equal(dataFrame.DataPayload, _helloWorldBytes); + } + + [Fact] + public async Task DATA_Received_Multiplexed_ReadByStreams() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await StartStreamAsync(3, _browserRequestHeaders, endStream: false); + + await SendDataAsync(1, _helloBytes, endStream: false); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + var stream1DataFrame1 = await ExpectAsync(Http2FrameType.DATA, + withLength: 5, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + + await SendDataAsync(3, _helloBytes, endStream: false); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 3); + var stream3DataFrame1 = await ExpectAsync(Http2FrameType.DATA, + withLength: 5, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 3); + + await SendDataAsync(3, _worldBytes, endStream: false); + + var stream3DataFrame2 = await ExpectAsync(Http2FrameType.DATA, + withLength: 5, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 3); + + await SendDataAsync(1, _worldBytes, endStream: false); + + var stream1DataFrame2 = await ExpectAsync(Http2FrameType.DATA, + withLength: 5, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + + await SendDataAsync(1, _noData, endStream: true); + + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await SendDataAsync(3, _noData, endStream: true); + + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 3); + + await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false); + + Assert.Equal(stream1DataFrame1.DataPayload, _helloBytes); + Assert.Equal(stream1DataFrame2.DataPayload, _worldBytes); + Assert.Equal(stream3DataFrame1.DataPayload, _helloBytes); + Assert.Equal(stream3DataFrame2.DataPayload, _worldBytes); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(255)] + public async Task DATA_Received_WithPadding_ReadByStream(byte padLength) + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await SendDataWithPaddingAsync(1, _helloWorldBytes, padLength, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + var dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 12, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + Assert.Equal(dataFrame.DataPayload, _helloWorldBytes); + } + + [Fact] + public async Task DATA_Received_StreamIdZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendDataAsync(0, _noData, endStream: false); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdZero(Http2FrameType.DATA)); + } + + [Fact] + public async Task DATA_Received_StreamIdEven_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendDataAsync(2, _noData, endStream: false); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdEven(Http2FrameType.DATA, streamId: 2)); + } + + [Fact] + public async Task DATA_Received_PaddingEqualToFramePayloadLength_ConnectionError() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await SendInvalidDataFrameAsync(1, frameLength: 5, padLength: 5); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: true, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorPaddingTooLong(Http2FrameType.DATA)); + } + + [Fact] + public async Task DATA_Received_PaddingGreaterThanFramePayloadLength_ConnectionError() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await SendInvalidDataFrameAsync(1, frameLength: 5, padLength: 6); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: true, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorPaddingTooLong(Http2FrameType.DATA)); + } + + [Fact] + public async Task DATA_Received_FrameLengthZeroPaddingZero_ConnectionError() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await SendInvalidDataFrameAsync(1, frameLength: 0, padLength: 0); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: true, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorPaddingTooLong(Http2FrameType.DATA)); + } + + [Fact] + public async Task DATA_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendDataAsync(1, _helloWorldBytes, endStream: true); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorHeadersInterleaved(Http2FrameType.DATA, streamId: 1, headersStreamId: 1)); + } + + [Fact] + public async Task DATA_Received_StreamIdle_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendDataAsync(1, _helloWorldBytes, endStream: false); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdle(Http2FrameType.DATA, streamId: 1)); + } + + [Fact] + public async Task DATA_Received_StreamHalfClosedRemote_ConnectionError() + { + // Use _waitForAbortApplication so we know the stream will still be active when we send the illegal DATA frame + await InitializeConnectionAsync(_waitForAbortApplication); + + await StartStreamAsync(1, _postRequestHeaders, endStream: true); + + await SendDataAsync(1, _helloWorldBytes, endStream: false); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamHalfClosedRemote(Http2FrameType.DATA, streamId: 1)); + } + + [Fact] + public async Task DATA_Received_StreamClosed_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(1, _postRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await SendDataAsync(1, _helloWorldBytes, endStream: false); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamClosed(Http2FrameType.DATA, streamId: 1)); + } + + [Fact] + public async Task DATA_Received_StreamClosedImplicitly_ConnectionError() + { + // http://httpwg.org/specs/rfc7540.html#rfc.section.5.1.1 + // + // The first use of a new stream identifier implicitly closes all streams in the "idle" state that + // might have been initiated by that peer with a lower-valued stream identifier. For example, if a + // client sends a HEADERS frame on stream 7 without ever sending a frame on stream 5, then stream 5 + // transitions to the "closed" state when the first frame for stream 7 is sent or received. + + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(3, _browserRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 3); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 3); + + await SendDataAsync(1, _helloWorldBytes, endStream: true); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 3, + expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamClosed(Http2FrameType.DATA, streamId: 1)); + } + + [Fact] + public async Task HEADERS_Received_Decoded() + { + await InitializeConnectionAsync(_readHeadersApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + VerifyDecodedRequestHeaders(_browserRequestHeaders); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(255)] + public async Task HEADERS_Received_WithPadding_Decoded(byte padLength) + { + await InitializeConnectionAsync(_readHeadersApplication); + + await SendHeadersWithPaddingAsync(1, _browserRequestHeaders, padLength, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + VerifyDecodedRequestHeaders(_browserRequestHeaders); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task HEADERS_Received_WithPriority_Decoded() + { + await InitializeConnectionAsync(_readHeadersApplication); + + await SendHeadersWithPriorityAsync(1, _browserRequestHeaders, priority: 42, streamDependency: 0, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + VerifyDecodedRequestHeaders(_browserRequestHeaders); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(255)] + public async Task HEADERS_Received_WithPriorityAndPadding_Decoded(byte padLength) + { + await InitializeConnectionAsync(_readHeadersApplication); + + await SendHeadersWithPaddingAndPriorityAsync(1, _browserRequestHeaders, padLength, priority: 42, streamDependency: 0, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + VerifyDecodedRequestHeaders(_browserRequestHeaders); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task HEADERS_Received_WithTrailers_Decoded(bool sendData) + { + await InitializeConnectionAsync(_readTrailersApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, _browserRequestHeaders); + + // Initialize another stream with a higher stream ID, and verify that after trailers are + // decoded by the other stream, the highest opened stream ID is not reset to the lower ID + // (the highest opened stream ID is sent by the server in the GOAWAY frame when shutting + // down the connection). + await SendHeadersAsync(3, Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM, _browserRequestHeaders); + + // The second stream should end first, since the first one is waiting for the request body. + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 3); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 3); + + if (sendData) + { + await SendDataAsync(1, _helloBytes, endStream: false); + } + + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM, _requestTrailers); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + VerifyDecodedRequestHeaders(_browserRequestHeaders.Concat(_requestTrailers)); + + await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task HEADERS_Received_ContainsExpect100Continue_100ContinueSent() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _expectContinueRequestHeaders, false); + + var frame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 5, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + await SendDataAsync(1, _helloBytes, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 5, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + Assert.Equal(new byte[] { 0x08, 0x03, (byte)'1', (byte)'0', (byte)'0' }, frame.HeadersPayload.ToArray()); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task HEADERS_Received_StreamIdZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(0, _browserRequestHeaders, endStream: true); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdZero(Http2FrameType.HEADERS)); + } + + [Fact] + public async Task HEADERS_Received_StreamIdEven_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(2, _browserRequestHeaders, endStream: true); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdEven(Http2FrameType.HEADERS, streamId: 2)); + } + + [Fact] + public async Task HEADERS_Received_StreamClosed_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + // Try to re-use the stream ID (http://httpwg.org/specs/rfc7540.html#rfc.section.5.1.1) + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamClosed(Http2FrameType.HEADERS, streamId: 1)); + } + + [Fact] + public async Task HEADERS_Received_StreamHalfClosedRemote_ConnectionError() + { + // Use _waitForAbortApplication so we know the stream will still be active when we send the illegal DATA frame + await InitializeConnectionAsync(_waitForAbortApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamHalfClosedRemote(Http2FrameType.HEADERS, streamId: 1)); + } + + [Fact] + public async Task HEADERS_Received_StreamClosedImplicitly_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(3, _browserRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 3); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 3); + + // Stream 1 was implicitly closed by opening stream 3 before (http://httpwg.org/specs/rfc7540.html#rfc.section.5.1.1) + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 3, + expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamClosed(Http2FrameType.HEADERS, streamId: 1)); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(255)] + public async Task HEADERS_Received_PaddingEqualToFramePayloadLength_ConnectionError(byte padLength) + { + await InitializeConnectionAsync(_noopApplication); + + await SendInvalidHeadersFrameAsync(1, frameLength: padLength, padLength: padLength); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: true, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorPaddingTooLong(Http2FrameType.HEADERS)); + } + + [Theory] + [InlineData(0, 1)] + [InlineData(1, 2)] + [InlineData(254, 255)] + public async Task HEADERS_Received_PaddingGreaterThanFramePayloadLength_ConnectionError(int frameLength, byte padLength) + { + await InitializeConnectionAsync(_noopApplication); + + await SendInvalidHeadersFrameAsync(1, frameLength, padLength); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: true, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorPaddingTooLong(Http2FrameType.HEADERS)); + } + + [Fact] + public async Task HEADERS_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendHeadersAsync(3, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorHeadersInterleaved(Http2FrameType.HEADERS, streamId: 3, headersStreamId: 1)); + } + + [Fact] + public async Task HEADERS_Received_WithPriority_StreamDependencyOnSelf_ConnectionError() + { + await InitializeConnectionAsync(_readHeadersApplication); + + await SendHeadersWithPriorityAsync(1, _browserRequestHeaders, priority: 42, streamDependency: 1, endStream: true); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamSelfDependency(Http2FrameType.HEADERS, streamId: 1)); + } + + [Fact] + public async Task HEADERS_Received_IncompleteHeaderBlock_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendIncompleteHeadersFrameAsync(streamId: 1); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.COMPRESSION_ERROR, + expectedErrorMessage: CoreStrings.HPackErrorIncompleteHeaderBlock); + } + + [Theory] + [MemberData(nameof(IllegalTrailerData))] + public async Task HEADERS_Received_WithTrailers_ContainsIllegalTrailer_ConnectionError(byte[] trailers, string expectedErrorMessage) + { + await InitializeConnectionAsync(_readTrailersApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, _browserRequestHeaders); + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM, trailers); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: expectedErrorMessage); + } + + [Theory] + [InlineData(Http2HeadersFrameFlags.NONE)] + [InlineData(Http2HeadersFrameFlags.END_HEADERS)] + public async Task HEADERS_Received_WithTrailers_EndStreamNotSet_ConnectionError(Http2HeadersFrameFlags flags) + { + await InitializeConnectionAsync(_readTrailersApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, _browserRequestHeaders); + await SendHeadersAsync(1, flags, _requestTrailers); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.Http2ErrorHeadersWithTrailersNoEndStream); + } + + [Theory] + [MemberData(nameof(UpperCaseHeaderNameData))] + public async Task HEADERS_Received_HeaderNameContainsUpperCaseCharacter_StreamError(byte[] headerBlock) + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, headerBlock); + await WaitForStreamErrorAsync( + ignoreNonRstStreamFrames: false, + expectedStreamId: 1, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.Http2ErrorHeaderNameUppercase); + + // Verify that the stream ID can't be re-used + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, _browserRequestHeaders); + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamClosed(Http2FrameType.HEADERS, streamId: 1)); + } + + [Fact] + public Task HEADERS_Received_HeaderBlockContainsUnknownPseudoHeaderField_StreamError() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":unknown", "0"), + }; + + return HEADERS_Received_InvalidHeaderFields_StreamError(headers, expectedErrorMessage: CoreStrings.Http2ErrorUnknownPseudoHeaderField); + } + + [Fact] + public Task HEADERS_Received_HeaderBlockContainsResponsePseudoHeaderField_StreamError() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":status", "200"), + }; + + return HEADERS_Received_InvalidHeaderFields_StreamError(headers, expectedErrorMessage: CoreStrings.Http2ErrorResponsePseudoHeaderField); + } + + [Theory] + [MemberData(nameof(DuplicatePseudoHeaderFieldData))] + public Task HEADERS_Received_HeaderBlockContainsDuplicatePseudoHeaderField_StreamError(IEnumerable> headers) + { + return HEADERS_Received_InvalidHeaderFields_StreamError(headers, expectedErrorMessage: CoreStrings.Http2ErrorDuplicatePseudoHeaderField); + } + + [Theory] + [MemberData(nameof(MissingPseudoHeaderFieldData))] + public Task HEADERS_Received_HeaderBlockDoesNotContainMandatoryPseudoHeaderField_StreamError(IEnumerable> headers) + { + return HEADERS_Received_InvalidHeaderFields_StreamError(headers, expectedErrorMessage: CoreStrings.Http2ErrorMissingMandatoryPseudoHeaderFields); + } + + [Theory] + [MemberData(nameof(ConnectMissingPseudoHeaderFieldData))] + public async Task HEADERS_Received_HeaderBlockDoesNotContainMandatoryPseudoHeaderField_MethodIsCONNECT_NoError(IEnumerable> headers) + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM, headers); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2HeadersFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Theory] + [MemberData(nameof(PseudoHeaderFieldAfterRegularHeadersData))] + public Task HEADERS_Received_HeaderBlockContainsPseudoHeaderFieldAfterRegularHeaders_StreamError(IEnumerable> headers) + { + return HEADERS_Received_InvalidHeaderFields_StreamError(headers, expectedErrorMessage: CoreStrings.Http2ErrorPseudoHeaderFieldAfterRegularHeaders); + } + + private async Task HEADERS_Received_InvalidHeaderFields_StreamError(IEnumerable> headers, string expectedErrorMessage) + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, headers); + await WaitForStreamErrorAsync( + ignoreNonRstStreamFrames: false, + expectedStreamId: 1, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: expectedErrorMessage); + + // Verify that the stream ID can't be re-used + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, _browserRequestHeaders); + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamClosed(Http2FrameType.HEADERS, streamId: 1)); + } + + [Fact] + public Task HEADERS_Received_HeaderBlockContainsConnectionHeader_StreamError() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair("connection", "keep-alive") + }; + + return HEADERS_Received_InvalidHeaderFields_StreamError(headers, CoreStrings.Http2ErrorConnectionSpecificHeaderField); + } + + [Fact] + public Task HEADERS_Received_HeaderBlockContainsTEHeader_ValueIsNotTrailers_StreamError() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair("te", "trailers, deflate") + }; + + return HEADERS_Received_InvalidHeaderFields_StreamError(headers, CoreStrings.Http2ErrorConnectionSpecificHeaderField); + } + + [Fact] + public async Task HEADERS_Received_HeaderBlockContainsTEHeader_ValueIsTrailers_NoError() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair("te", "trailers") + }; + + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM, headers); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2HeadersFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task HEADERS_Received_InvalidAuthority_400Status() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":authority", "local=host:80"), + }; + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("400", _decodedHeaders[":status"]); + Assert.Equal("0", _decodedHeaders["content-length"]); + } + + [Fact] + public async Task HEADERS_Received_MissingAuthority_400Status() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + }; + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("400", _decodedHeaders[":status"]); + Assert.Equal("0", _decodedHeaders["content-length"]); + } + + [Fact] + public async Task HEADERS_Received_TwoHosts_400Status() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair("Host", "host1"), + new KeyValuePair("Host", "host2"), + }; + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("400", _decodedHeaders[":status"]); + Assert.Equal("0", _decodedHeaders["content-length"]); + } + + [Fact] + public async Task HEADERS_Received_EmptyAuthority_200Status() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":authority", ""), + }; + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[":status"]); + Assert.Equal("0", _decodedHeaders["content-length"]); + } + + [Fact] + public async Task HEADERS_Received_EmptyAuthorityOverridesHost_200Status() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":authority", ""), + new KeyValuePair("Host", "abc"), + }; + await InitializeConnectionAsync(_echoHost); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 62, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(4, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[":status"]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + Assert.Equal("", _decodedHeaders[HeaderNames.Host]); + } + + [Fact] + public async Task HEADERS_Received_AuthorityOverridesHost_200Status() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":authority", "def"), + new KeyValuePair("Host", "abc"), + }; + await InitializeConnectionAsync(_echoHost); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 65, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(4, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[":status"]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + Assert.Equal("def", _decodedHeaders[HeaderNames.Host]); + } + + [Fact] + public async Task HEADERS_Received_MissingAuthorityFallsBackToHost_200Status() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair("Host", "abc"), + }; + await InitializeConnectionAsync(_echoHost); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 65, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(4, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[":status"]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + Assert.Equal("abc", _decodedHeaders[HeaderNames.Host]); + } + + [Fact] + public async Task HEADERS_Received_AuthorityOverridesInvalidHost_200Status() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":authority", "def"), + new KeyValuePair("Host", "a=bc"), + }; + await InitializeConnectionAsync(_echoHost); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 65, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(4, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[":status"]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + Assert.Equal("def", _decodedHeaders[HeaderNames.Host]); + } + + [Fact] + public async Task HEADERS_Received_InvalidAuthorityWithValidHost_400Status() + { + var headers = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":authority", "d=ef"), + new KeyValuePair("Host", "abc"), + }; + await InitializeConnectionAsync(_echoHost); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("400", _decodedHeaders[":status"]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + } + + [Fact] + public async Task PRIORITY_Received_StreamIdZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendPriorityAsync(0); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdZero(Http2FrameType.PRIORITY)); + } + + [Fact] + public async Task PRIORITY_Received_StreamIdEven_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendPriorityAsync(2); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdEven(Http2FrameType.PRIORITY, streamId: 2)); + } + + [Theory] + [InlineData(4)] + [InlineData(6)] + public async Task PRIORITY_Received_LengthNotFive_ConnectionError(int length) + { + await InitializeConnectionAsync(_noopApplication); + + await SendInvalidPriorityFrameAsync(1, length); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.FRAME_SIZE_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorUnexpectedFrameLength(Http2FrameType.PRIORITY, expectedLength: 5)); + } + + [Fact] + public async Task PRIORITY_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendPriorityAsync(1); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorHeadersInterleaved(Http2FrameType.PRIORITY, streamId: 1, headersStreamId: 1)); + } + + [Fact] + public async Task PRIORITY_Received_StreamDependencyOnSelf_ConnectionError() + { + await InitializeConnectionAsync(_readHeadersApplication); + + await SendPriorityAsync(1, streamDependency: 1); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamSelfDependency(Http2FrameType.PRIORITY, 1)); + } + + [Fact] + public async Task RST_STREAM_Received_AbortsStream() + { + await InitializeConnectionAsync(_waitForAbortApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + await SendRstStreamAsync(1); + + // No data is received from the stream since it was aborted before writing anything + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + await WaitForAllStreamsAsync(); + Assert.Contains(1, _abortedStreamIds); + } + + [Fact] + public async Task RST_STREAM_Received_AbortsStream_FlushedDataIsSent() + { + await InitializeConnectionAsync(_waitForAbortFlushingApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + await SendRstStreamAsync(1); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + // No END_STREAM DATA frame is received since the stream was aborted + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + Assert.Contains(1, _abortedStreamIds); + } + + [Fact] + public async Task RST_STREAM_Received_StreamIdZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendRstStreamAsync(0); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdZero(Http2FrameType.RST_STREAM)); + } + + [Fact] + public async Task RST_STREAM_Received_StreamIdEven_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendRstStreamAsync(2); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdEven(Http2FrameType.RST_STREAM, streamId: 2)); + } + + [Fact] + public async Task RST_STREAM_Received_StreamIdle_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendRstStreamAsync(1); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdle(Http2FrameType.RST_STREAM, streamId: 1)); + } + + [Theory] + [InlineData(3)] + [InlineData(5)] + public async Task RST_STREAM_Received_LengthNotFour_ConnectionError(int length) + { + await InitializeConnectionAsync(_noopApplication); + + // Start stream 1 so it's legal to send it RST_STREAM frames + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + await SendInvalidRstStreamFrameAsync(1, length); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: true, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.FRAME_SIZE_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorUnexpectedFrameLength(Http2FrameType.RST_STREAM, expectedLength: 4)); + } + + [Fact] + public async Task RST_STREAM_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendRstStreamAsync(1); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorHeadersInterleaved(Http2FrameType.RST_STREAM, streamId: 1, headersStreamId: 1)); + } + + [Fact] + public async Task SETTINGS_Received_Sends_ACK() + { + await InitializeConnectionAsync(_noopApplication); + + await StopConnectionAsync(expectedLastStreamId: 0, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task SETTINGS_Received_StreamIdNotZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendSettingsWithInvalidStreamIdAsync(1); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdNotZero(Http2FrameType.SETTINGS)); + } + + [Theory] + [InlineData(Http2SettingsParameter.SETTINGS_ENABLE_PUSH, 2, Http2ErrorCode.PROTOCOL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_ENABLE_PUSH, uint.MaxValue, Http2ErrorCode.PROTOCOL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_INITIAL_WINDOW_SIZE, (uint)int.MaxValue + 1, Http2ErrorCode.FLOW_CONTROL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_INITIAL_WINDOW_SIZE, uint.MaxValue, Http2ErrorCode.FLOW_CONTROL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, 0, Http2ErrorCode.PROTOCOL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, 1, Http2ErrorCode.PROTOCOL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, 16 * 1024 - 1, Http2ErrorCode.PROTOCOL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, 16 * 1024 * 1024, Http2ErrorCode.PROTOCOL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, uint.MaxValue, Http2ErrorCode.PROTOCOL_ERROR)] + public async Task SETTINGS_Received_InvalidParameterValue_ConnectionError(Http2SettingsParameter parameter, uint value, Http2ErrorCode expectedErrorCode) + { + await InitializeConnectionAsync(_noopApplication); + + await SendSettingsWithInvalidParameterValueAsync(parameter, value); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: expectedErrorCode, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorSettingsParameterOutOfRange(parameter)); + } + + [Fact] + public async Task SETTINGS_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendSettingsAsync(); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorHeadersInterleaved(Http2FrameType.SETTINGS, streamId: 0, headersStreamId: 1)); + } + + [Theory] + [InlineData(1)] + [InlineData(16 * 1024 - 9)] // Min. max. frame size minus header length + public async Task SETTINGS_Received_WithACK_LengthNotZero_ConnectionError(int length) + { + await InitializeConnectionAsync(_noopApplication); + + await SendSettingsAckWithInvalidLengthAsync(length); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.FRAME_SIZE_ERROR, + expectedErrorMessage: CoreStrings.Http2ErrorSettingsAckLengthNotZero); + } + + [Theory] + [InlineData(1)] + [InlineData(5)] + [InlineData(7)] + [InlineData(34)] + [InlineData(37)] + public async Task SETTINGS_Received_LengthNotMultipleOfSix_ConnectionError(int length) + { + await InitializeConnectionAsync(_noopApplication); + + await SendSettingsWithInvalidLengthAsync(length); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.FRAME_SIZE_ERROR, + expectedErrorMessage: CoreStrings.Http2ErrorSettingsLengthNotMultipleOfSix); + } + + [Fact] + public async Task PUSH_PROMISE_Received_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendPushPromiseFrameAsync(); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.Http2ErrorPushPromiseReceived); + } + + [Fact] + public async Task PING_Received_SendsACK() + { + await InitializeConnectionAsync(_noopApplication); + + await SendPingAsync(Http2PingFrameFlags.NONE); + await ExpectAsync(Http2FrameType.PING, + withLength: 8, + withFlags: (byte)Http2PingFrameFlags.ACK, + withStreamId: 0); + + await StopConnectionAsync(expectedLastStreamId: 0, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task PING_Received_WithACK_DoesNotSendACK() + { + await InitializeConnectionAsync(_noopApplication); + + await SendPingAsync(Http2PingFrameFlags.ACK); + + await StopConnectionAsync(expectedLastStreamId: 0, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task PING_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendPingAsync(Http2PingFrameFlags.NONE); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorHeadersInterleaved(Http2FrameType.PING, streamId: 0, headersStreamId: 1)); + } + + [Fact] + public async Task PING_Received_StreamIdNotZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendPingWithInvalidStreamIdAsync(streamId: 1); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdNotZero(Http2FrameType.PING)); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(7)] + [InlineData(9)] + public async Task PING_Received_LengthNotEight_ConnectionError(int length) + { + await InitializeConnectionAsync(_noopApplication); + + await SendPingWithInvalidLengthAsync(length); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.FRAME_SIZE_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorUnexpectedFrameLength(Http2FrameType.PING, expectedLength: 8)); + } + + [Fact] + public async Task GOAWAY_Received_ConnectionStops() + { + await InitializeConnectionAsync(_noopApplication); + + await SendGoAwayAsync(); + + await WaitForConnectionStopAsync(expectedLastStreamId: 0, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task GOAWAY_Received_AbortsAllStreams() + { + await InitializeConnectionAsync(_waitForAbortApplication); + + // Start some streams + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + await StartStreamAsync(3, _browserRequestHeaders, endStream: true); + await StartStreamAsync(5, _browserRequestHeaders, endStream: true); + + await SendGoAwayAsync(); + + await WaitForConnectionStopAsync(expectedLastStreamId: 5, ignoreNonGoAwayFrames: true); + + await WaitForAllStreamsAsync(); + Assert.Contains(1, _abortedStreamIds); + Assert.Contains(3, _abortedStreamIds); + Assert.Contains(5, _abortedStreamIds); + } + + [Fact] + public async Task GOAWAY_Received_StreamIdNotZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendInvalidGoAwayFrameAsync(); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdNotZero(Http2FrameType.GOAWAY)); + } + + [Fact] + public async Task GOAWAY_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendGoAwayAsync(); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorHeadersInterleaved(Http2FrameType.GOAWAY, streamId: 0, headersStreamId: 1)); + } + + [Fact] + public async Task WINDOW_UPDATE_Received_StreamIdEven_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendWindowUpdateAsync(2, sizeIncrement: 42); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdEven(Http2FrameType.WINDOW_UPDATE, streamId: 2)); + } + + [Fact] + public async Task WINDOW_UPDATE_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendWindowUpdateAsync(1, sizeIncrement: 42); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorHeadersInterleaved(Http2FrameType.WINDOW_UPDATE, streamId: 1, headersStreamId: 1)); + } + + [Theory] + [InlineData(0, 3)] + [InlineData(0, 5)] + [InlineData(1, 3)] + [InlineData(1, 5)] + public async Task WINDOW_UPDATE_Received_LengthNotFour_ConnectionError(int streamId, int length) + { + await InitializeConnectionAsync(_noopApplication); + + await SendInvalidWindowUpdateAsync(streamId, sizeIncrement: 42, length: length); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.FRAME_SIZE_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorUnexpectedFrameLength(Http2FrameType.WINDOW_UPDATE, expectedLength: 4)); + } + + [Fact] + public async Task WINDOW_UPDATE_Received_OnConnection_SizeIncrementZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendWindowUpdateAsync(0, sizeIncrement: 0); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.Http2ErrorWindowUpdateIncrementZero); + } + + [Fact] + public async Task WINDOW_UPDATE_Received_OnStream_SizeIncrementZero_ConnectionError() + { + await InitializeConnectionAsync(_waitForAbortApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + await SendWindowUpdateAsync(1, sizeIncrement: 0); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.Http2ErrorWindowUpdateIncrementZero); + } + + [Fact] + public async Task WINDOW_UPDATE_Received_StreamIdle_ConnectionError() + { + await InitializeConnectionAsync(_waitForAbortApplication); + + await SendWindowUpdateAsync(1, sizeIncrement: 1); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdle(Http2FrameType.WINDOW_UPDATE, streamId: 1)); + } + + [Fact] + public async Task CONTINUATION_Received_Decoded() + { + await InitializeConnectionAsync(_readHeadersApplication); + + await StartStreamAsync(1, _twoContinuationsRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2HeadersFrameFlags.END_STREAM, + withStreamId: 1); + + VerifyDecodedRequestHeaders(_twoContinuationsRequestHeaders); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CONTINUATION_Received_WithTrailers_Decoded(bool sendData) + { + await InitializeConnectionAsync(_readTrailersApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, _browserRequestHeaders); + + // Initialize another stream with a higher stream ID, and verify that after trailers are + // decoded by the other stream, the highest opened stream ID is not reset to the lower ID + // (the highest opened stream ID is sent by the server in the GOAWAY frame when shutting + // down the connection). + await SendHeadersAsync(3, Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM, _browserRequestHeaders); + + // The second stream should end first, since the first one is waiting for the request body. + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 3); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 3); + + if (sendData) + { + await SendDataAsync(1, _helloBytes, endStream: false); + } + + // Trailers encoded as Literal Header Field without Indexing - New Name + // trailer-1: 1 + // trailer-2: 2 + var trailers = new byte[] { 0x00, 0x09 } + .Concat(Encoding.ASCII.GetBytes("trailer-1")) + .Concat(new byte[] { 0x01, (byte)'1' }) + .Concat(new byte[] { 0x00, 0x09 }) + .Concat(Encoding.ASCII.GetBytes("trailer-2")) + .Concat(new byte[] { 0x01, (byte)'2' }) + .ToArray(); + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_STREAM, new byte[0]); + await SendContinuationAsync(1, Http2ContinuationFrameFlags.END_HEADERS, trailers); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + VerifyDecodedRequestHeaders(_browserRequestHeaders.Concat(new[] + { + new KeyValuePair("trailer-1", "1"), + new KeyValuePair("trailer-2", "2") + })); + + await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task CONTINUATION_Received_StreamIdMismatch_ConnectionError() + { + await InitializeConnectionAsync(_readHeadersApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _oneContinuationRequestHeaders); + await SendContinuationAsync(3, Http2ContinuationFrameFlags.END_HEADERS); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorHeadersInterleaved(Http2FrameType.CONTINUATION, streamId: 3, headersStreamId: 1)); + } + + [Fact] + public async Task CONTINUATION_Received_IncompleteHeaderBlock_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _postRequestHeaders); + await SendIncompleteContinuationFrameAsync(streamId: 1); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.COMPRESSION_ERROR, + expectedErrorMessage: CoreStrings.HPackErrorIncompleteHeaderBlock); + } + + [Theory] + [MemberData(nameof(IllegalTrailerData))] + public async Task CONTINUATION_Received_WithTrailers_ContainsIllegalTrailer_ConnectionError(byte[] trailers, string expectedErrorMessage) + { + await InitializeConnectionAsync(_readTrailersApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, _browserRequestHeaders); + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_STREAM, new byte[0]); + await SendContinuationAsync(1, Http2ContinuationFrameFlags.END_HEADERS, trailers); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: expectedErrorMessage); + } + + [Theory] + [MemberData(nameof(MissingPseudoHeaderFieldData))] + public async Task CONTINUATION_Received_HeaderBlockDoesNotContainMandatoryPseudoHeaderField_StreamError(IEnumerable> headers) + { + await InitializeConnectionAsync(_noopApplication); + + Assert.True(await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, headers)); + await SendEmptyContinuationFrameAsync(1, Http2ContinuationFrameFlags.END_HEADERS); + + await WaitForStreamErrorAsync( + ignoreNonRstStreamFrames: false, + expectedStreamId: 1, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.Http2ErrorMissingMandatoryPseudoHeaderFields); + + // Verify that the stream ID can't be re-used + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, headers); + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 1, + expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamClosed(Http2FrameType.HEADERS, streamId: 1)); + } + + [Theory] + [MemberData(nameof(ConnectMissingPseudoHeaderFieldData))] + public async Task CONTINUATION_Received_HeaderBlockDoesNotContainMandatoryPseudoHeaderField_MethodIsCONNECT_NoError(IEnumerable> headers) + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_STREAM, headers); + await SendEmptyContinuationFrameAsync(1, Http2ContinuationFrameFlags.END_HEADERS); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2HeadersFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task CONTINUATION_Sent_WhenHeadersLargerThanFrameLength() + { + await InitializeConnectionAsync(_largeHeadersApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 12361, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + var continuationFrame1 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 12306, + withFlags: (byte)Http2ContinuationFrameFlags.NONE, + withStreamId: 1); + var continuationFrame2 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 8204, + withFlags: (byte)Http2ContinuationFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + _hpackDecoder.Decode(continuationFrame1.HeadersPayload, endHeaders: false, handler: this); + _hpackDecoder.Decode(continuationFrame2.HeadersPayload, endHeaders: true, handler: this); + + Assert.Equal(11, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[":status"]); + Assert.Equal("0", _decodedHeaders["content-length"]); + Assert.Equal(_largeHeaderValue, _decodedHeaders["a"]); + Assert.Equal(_largeHeaderValue, _decodedHeaders["b"]); + Assert.Equal(_largeHeaderValue, _decodedHeaders["c"]); + Assert.Equal(_largeHeaderValue, _decodedHeaders["d"]); + Assert.Equal(_largeHeaderValue, _decodedHeaders["e"]); + Assert.Equal(_largeHeaderValue, _decodedHeaders["f"]); + Assert.Equal(_largeHeaderValue, _decodedHeaders["g"]); + Assert.Equal(_largeHeaderValue, _decodedHeaders["h"]); + } + + [Fact] + public async Task UnknownFrameType_Received_Ignored() + { + await InitializeConnectionAsync(_noopApplication); + + await SendUnknownFrameTypeAsync(streamId: 1, frameType: 42); + + // Check that the connection is still alive + await SendPingAsync(Http2PingFrameFlags.NONE); + await ExpectAsync(Http2FrameType.PING, + withLength: 8, + withFlags: (byte)Http2PingFrameFlags.ACK, + withStreamId: 0); + + await StopConnectionAsync(0, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task UnknownFrameType_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendUnknownFrameTypeAsync(streamId: 1, frameType: 42); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 0, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorHeadersInterleaved(frameType: 42, streamId: 1, headersStreamId: 1)); + } + + [Fact] + public async Task ConnectionErrorAbortsAllStreams() + { + await InitializeConnectionAsync(_waitForAbortApplication); + + // Start some streams + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + await StartStreamAsync(3, _browserRequestHeaders, endStream: true); + await StartStreamAsync(5, _browserRequestHeaders, endStream: true); + + // Cause a connection error by sending an invalid frame + await SendDataAsync(0, _noData, endStream: false); + + await WaitForConnectionErrorAsync( + ignoreNonGoAwayFrames: false, + expectedLastStreamId: 5, + expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, + expectedErrorMessage: CoreStrings.FormatHttp2ErrorStreamIdZero(Http2FrameType.DATA)); + + await WaitForAllStreamsAsync(); + Assert.Contains(1, _abortedStreamIds); + Assert.Contains(3, _abortedStreamIds); + Assert.Contains(5, _abortedStreamIds); + } + + [Fact] + public async Task ConnectionResetLoggedWithActiveStreams() + { + await InitializeConnectionAsync(_waitForAbortApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM, _browserRequestHeaders); + + _pair.Application.Output.Complete(new ConnectionResetException(string.Empty)); + + var result = await _pair.Application.Input.ReadAsync(); + Assert.True(result.IsCompleted); + Assert.Single(_logger.Messages, m => m.Exception is ConnectionResetException); + } + + [Fact] + public async Task ConnectionResetNotLoggedWithNoActiveStreams() + { + await InitializeConnectionAsync(_waitForAbortApplication); + + _pair.Application.Output.Complete(new ConnectionResetException(string.Empty)); + + var result = await _pair.Application.Input.ReadAsync(); + Assert.True(result.IsCompleted); + Assert.DoesNotContain(_logger.Messages, m => m.Exception is ConnectionResetException); + } + + private async Task InitializeConnectionAsync(RequestDelegate application) + { + _connectionTask = _connection.ProcessRequestsAsync(new DummyApplication(application)); + + await SendPreambleAsync().ConfigureAwait(false); + await SendSettingsAsync(); + + await ExpectAsync(Http2FrameType.SETTINGS, + withLength: 0, + withFlags: 0, + withStreamId: 0); + + await ExpectAsync(Http2FrameType.SETTINGS, + withLength: 0, + withFlags: (byte)Http2SettingsFrameFlags.ACK, + withStreamId: 0); + } + + private async Task StartStreamAsync(int streamId, IEnumerable> headers, bool endStream) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _runningStreams[streamId] = tcs; + + var frame = new Http2Frame(); + frame.PrepareHeaders(Http2HeadersFrameFlags.NONE, streamId); + var done = _hpackEncoder.BeginEncode(headers, frame.HeadersPayload, out var length); + frame.Length = length; + + if (done) + { + frame.HeadersFlags = Http2HeadersFrameFlags.END_HEADERS; + } + + if (endStream) + { + frame.HeadersFlags |= Http2HeadersFrameFlags.END_STREAM; + } + + await SendAsync(frame.Raw); + + while (!done) + { + frame.PrepareContinuation(Http2ContinuationFrameFlags.NONE, streamId); + done = _hpackEncoder.Encode(frame.HeadersPayload, out length); + frame.Length = length; + + if (done) + { + frame.ContinuationFlags = Http2ContinuationFrameFlags.END_HEADERS; + } + + await SendAsync(frame.Raw); + } + } + + private async Task SendHeadersWithPaddingAsync(int streamId, IEnumerable> headers, byte padLength, bool endStream) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _runningStreams[streamId] = tcs; + + var frame = new Http2Frame(); + + frame.PrepareHeaders(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.PADDED, streamId); + frame.HeadersPadLength = padLength; + + _hpackEncoder.BeginEncode(headers, frame.HeadersPayload, out var length); + + frame.Length = 1 + length + padLength; + frame.Payload.Slice(1 + length).Fill(0); + + if (endStream) + { + frame.HeadersFlags |= Http2HeadersFrameFlags.END_STREAM; + } + + await SendAsync(frame.Raw); + } + + private async Task SendHeadersWithPriorityAsync(int streamId, IEnumerable> headers, byte priority, int streamDependency, bool endStream) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _runningStreams[streamId] = tcs; + + var frame = new Http2Frame(); + frame.PrepareHeaders(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.PRIORITY, streamId); + frame.HeadersPriority = priority; + frame.HeadersStreamDependency = streamDependency; + + _hpackEncoder.BeginEncode(headers, frame.HeadersPayload, out var length); + + frame.Length = 5 + length; + + if (endStream) + { + frame.HeadersFlags |= Http2HeadersFrameFlags.END_STREAM; + } + + await SendAsync(frame.Raw); + } + + private async Task SendHeadersWithPaddingAndPriorityAsync(int streamId, IEnumerable> headers, byte padLength, byte priority, int streamDependency, bool endStream) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _runningStreams[streamId] = tcs; + + var frame = new Http2Frame(); + frame.PrepareHeaders(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.PADDED | Http2HeadersFrameFlags.PRIORITY, streamId); + frame.HeadersPadLength = padLength; + frame.HeadersPriority = priority; + frame.HeadersStreamDependency = streamDependency; + + _hpackEncoder.BeginEncode(headers, frame.HeadersPayload, out var length); + + frame.Length = 6 + length + padLength; + frame.Payload.Slice(6 + length).Fill(0); + + if (endStream) + { + frame.HeadersFlags |= Http2HeadersFrameFlags.END_STREAM; + } + + await SendAsync(frame.Raw); + } + + private Task SendStreamDataAsync(int streamId, Span data) + { + var tasks = new List(); + var frame = new Http2Frame(); + + frame.PrepareData(streamId); + + while (data.Length > frame.Length) + { + data.Slice(0, frame.Length).CopyTo(frame.Payload); + data = data.Slice(frame.Length); + tasks.Add(SendAsync(frame.Raw)); + } + + frame.Length = data.Length; + frame.DataFlags = Http2DataFrameFlags.END_STREAM; + data.CopyTo(frame.Payload); + tasks.Add(SendAsync(frame.Raw)); + + return Task.WhenAll(tasks); + } + + private Task WaitForAllStreamsAsync() + { + return Task.WhenAll(_runningStreams.Values.Select(tcs => tcs.Task)).DefaultTimeout(); + } + + private Task SendAsync(ReadOnlySpan span) + { + var writableBuffer = _pair.Application.Output; + writableBuffer.Write(span); + return FlushAsync(writableBuffer); + } + + private static async Task FlushAsync(PipeWriter writableBuffer) + { + await writableBuffer.FlushAsync(); + } + + private Task SendPreambleAsync() => SendAsync(new ArraySegment(Http2Connection.ClientPreface)); + + private Task SendSettingsAsync() + { + var frame = new Http2Frame(); + frame.PrepareSettings(Http2SettingsFrameFlags.NONE, _clientSettings); + return SendAsync(frame.Raw); + } + + private Task SendSettingsAckWithInvalidLengthAsync(int length) + { + var frame = new Http2Frame(); + frame.PrepareSettings(Http2SettingsFrameFlags.ACK); + frame.Length = length; + return SendAsync(frame.Raw); + } + + private Task SendSettingsWithInvalidStreamIdAsync(int streamId) + { + var frame = new Http2Frame(); + frame.PrepareSettings(Http2SettingsFrameFlags.NONE, _clientSettings); + frame.StreamId = streamId; + return SendAsync(frame.Raw); + } + + private Task SendSettingsWithInvalidLengthAsync(int length) + { + var frame = new Http2Frame(); + frame.PrepareSettings(Http2SettingsFrameFlags.NONE, _clientSettings); + frame.Length = length; + return SendAsync(frame.Raw); + } + + private Task SendSettingsWithInvalidParameterValueAsync(Http2SettingsParameter parameter, uint value) + { + var frame = new Http2Frame(); + frame.PrepareSettings(Http2SettingsFrameFlags.NONE); + frame.Length = 6; + + frame.Payload[0] = (byte)((ushort)parameter >> 8); + frame.Payload[1] = (byte)(ushort)parameter; + frame.Payload[2] = (byte)(value >> 24); + frame.Payload[3] = (byte)(value >> 16); + frame.Payload[4] = (byte)(value >> 8); + frame.Payload[5] = (byte)value; + + return SendAsync(frame.Raw); + } + + private Task SendPushPromiseFrameAsync() + { + var frame = new Http2Frame(); + frame.Length = 0; + frame.Type = Http2FrameType.PUSH_PROMISE; + frame.StreamId = 1; + return SendAsync(frame.Raw); + } + + private async Task SendHeadersAsync(int streamId, Http2HeadersFrameFlags flags, IEnumerable> headers) + { + var frame = new Http2Frame(); + + frame.PrepareHeaders(flags, streamId); + var done = _hpackEncoder.BeginEncode(headers, frame.Payload, out var length); + frame.Length = length; + + await SendAsync(frame.Raw); + + return done; + } + + private Task SendHeadersAsync(int streamId, Http2HeadersFrameFlags flags, byte[] headerBlock) + { + var frame = new Http2Frame(); + + frame.PrepareHeaders(flags, streamId); + frame.Length = headerBlock.Length; + headerBlock.CopyTo(frame.HeadersPayload); + + return SendAsync(frame.Raw); + } + + private Task SendInvalidHeadersFrameAsync(int streamId, int frameLength, byte padLength) + { + Assert.True(padLength >= frameLength, $"{nameof(padLength)} must be greater than or equal to {nameof(frameLength)} to create an invalid frame."); + + var frame = new Http2Frame(); + + frame.PrepareHeaders(Http2HeadersFrameFlags.PADDED, streamId); + frame.Payload[0] = padLength; + + // Set length last so .Payload can be written to + frame.Length = frameLength; + + return SendAsync(frame.Raw); + } + + private Task SendIncompleteHeadersFrameAsync(int streamId) + { + var frame = new Http2Frame(); + + frame.PrepareHeaders(Http2HeadersFrameFlags.END_HEADERS, streamId); + frame.Length = 3; + + // Set up an incomplete Literal Header Field w/ Incremental Indexing frame, + // with an incomplete new name + frame.Payload[0] = 0; + frame.Payload[1] = 2; + frame.Payload[2] = (byte)'a'; + + return SendAsync(frame.Raw); + } + + private async Task SendContinuationAsync(int streamId, Http2ContinuationFrameFlags flags) + { + var frame = new Http2Frame(); + + frame.PrepareContinuation(flags, streamId); + var done = _hpackEncoder.Encode(frame.Payload, out var length); + frame.Length = length; + + await SendAsync(frame.Raw); + + return done; + } + + private async Task SendContinuationAsync(int streamId, Http2ContinuationFrameFlags flags, byte[] payload) + { + var frame = new Http2Frame(); + + frame.PrepareContinuation(flags, streamId); + frame.Length = payload.Length; + payload.CopyTo(frame.Payload); + + await SendAsync(frame.Raw); + } + + private Task SendEmptyContinuationFrameAsync(int streamId, Http2ContinuationFrameFlags flags) + { + var frame = new Http2Frame(); + + frame.PrepareContinuation(flags, streamId); + frame.Length = 0; + + return SendAsync(frame.Raw); + } + + private Task SendIncompleteContinuationFrameAsync(int streamId) + { + var frame = new Http2Frame(); + + frame.PrepareContinuation(Http2ContinuationFrameFlags.END_HEADERS, streamId); + frame.Length = 3; + + // Set up an incomplete Literal Header Field w/ Incremental Indexing frame, + // with an incomplete new name + frame.Payload[0] = 0; + frame.Payload[1] = 2; + frame.Payload[2] = (byte)'a'; + + return SendAsync(frame.Raw); + } + + private Task SendDataAsync(int streamId, Span data, bool endStream) + { + var frame = new Http2Frame(); + + frame.PrepareData(streamId); + frame.Length = data.Length; + frame.DataFlags = endStream ? Http2DataFrameFlags.END_STREAM : Http2DataFrameFlags.NONE; + data.CopyTo(frame.DataPayload); + + return SendAsync(frame.Raw); + } + + private Task SendDataWithPaddingAsync(int streamId, Span data, byte padLength, bool endStream) + { + var frame = new Http2Frame(); + + frame.PrepareData(streamId, padLength); + frame.Length = data.Length + 1 + padLength; + data.CopyTo(frame.DataPayload); + + if (endStream) + { + frame.DataFlags |= Http2DataFrameFlags.END_STREAM; + } + + return SendAsync(frame.Raw); + } + + private Task SendInvalidDataFrameAsync(int streamId, int frameLength, byte padLength) + { + Assert.True(padLength >= frameLength, $"{nameof(padLength)} must be greater than or equal to {nameof(frameLength)} to create an invalid frame."); + + var frame = new Http2Frame(); + + frame.PrepareData(streamId); + frame.DataFlags = Http2DataFrameFlags.PADDED; + frame.Payload[0] = padLength; + + // Set length last so .Payload can be written to + frame.Length = frameLength; + + return SendAsync(frame.Raw); + } + + private Task SendPingAsync(Http2PingFrameFlags flags) + { + var pingFrame = new Http2Frame(); + pingFrame.PreparePing(flags); + return SendAsync(pingFrame.Raw); + } + + private Task SendPingWithInvalidLengthAsync(int length) + { + var pingFrame = new Http2Frame(); + pingFrame.PreparePing(Http2PingFrameFlags.NONE); + pingFrame.Length = length; + return SendAsync(pingFrame.Raw); + } + + private Task SendPingWithInvalidStreamIdAsync(int streamId) + { + Assert.NotEqual(0, streamId); + + var pingFrame = new Http2Frame(); + pingFrame.PreparePing(Http2PingFrameFlags.NONE); + pingFrame.StreamId = streamId; + return SendAsync(pingFrame.Raw); + } + + private Task SendPriorityAsync(int streamId, int streamDependency = 0) + { + var priorityFrame = new Http2Frame(); + priorityFrame.PreparePriority(streamId, streamDependency: streamDependency, exclusive: false, weight: 0); + return SendAsync(priorityFrame.Raw); + } + + private Task SendInvalidPriorityFrameAsync(int streamId, int length) + { + var priorityFrame = new Http2Frame(); + priorityFrame.PreparePriority(streamId, streamDependency: 0, exclusive: false, weight: 0); + priorityFrame.Length = length; + return SendAsync(priorityFrame.Raw); + } + + private Task SendRstStreamAsync(int streamId) + { + var rstStreamFrame = new Http2Frame(); + rstStreamFrame.PrepareRstStream(streamId, Http2ErrorCode.CANCEL); + return SendAsync(rstStreamFrame.Raw); + } + + private Task SendInvalidRstStreamFrameAsync(int streamId, int length) + { + var frame = new Http2Frame(); + frame.PrepareRstStream(streamId, Http2ErrorCode.CANCEL); + frame.Length = length; + return SendAsync(frame.Raw); + } + + private Task SendGoAwayAsync() + { + var frame = new Http2Frame(); + frame.PrepareGoAway(0, Http2ErrorCode.NO_ERROR); + return SendAsync(frame.Raw); + } + + private Task SendInvalidGoAwayFrameAsync() + { + var frame = new Http2Frame(); + frame.PrepareGoAway(0, Http2ErrorCode.NO_ERROR); + frame.StreamId = 1; + return SendAsync(frame.Raw); + } + + private Task SendWindowUpdateAsync(int streamId, int sizeIncrement) + { + var frame = new Http2Frame(); + frame.PrepareWindowUpdate(streamId, sizeIncrement); + return SendAsync(frame.Raw); + } + + private Task SendInvalidWindowUpdateAsync(int streamId, int sizeIncrement, int length) + { + var frame = new Http2Frame(); + frame.PrepareWindowUpdate(streamId, sizeIncrement); + frame.Length = length; + return SendAsync(frame.Raw); + } + + private Task SendUnknownFrameTypeAsync(int streamId, int frameType) + { + var frame = new Http2Frame(); + frame.StreamId = streamId; + frame.Type = (Http2FrameType)frameType; + frame.Length = 0; + return SendAsync(frame.Raw); + } + + private async Task ReceiveFrameAsync() + { + var frame = new Http2Frame(); + + while (true) + { + var result = await _pair.Application.Input.ReadAsync(); + var buffer = result.Buffer; + var consumed = buffer.Start; + var examined = buffer.End; + + try + { + Assert.True(buffer.Length > 0); + + if (Http2FrameReader.ReadFrame(buffer, frame, out consumed, out examined)) + { + return frame; + } + } + finally + { + _pair.Application.Input.AdvanceTo(consumed, examined); + } + } + } + + private async Task ReceiveSettingsAck() + { + var frame = await ReceiveFrameAsync(); + + Assert.Equal(Http2FrameType.SETTINGS, frame.Type); + Assert.Equal(Http2SettingsFrameFlags.ACK, frame.SettingsFlags); + } + + private async Task ExpectAsync(Http2FrameType type, int withLength, byte withFlags, int withStreamId) + { + var frame = await ReceiveFrameAsync(); + + Assert.Equal(type, frame.Type); + Assert.Equal(withLength, frame.Length); + Assert.Equal(withFlags, frame.Flags); + Assert.Equal(withStreamId, frame.StreamId); + + return frame; + } + + private Task StopConnectionAsync(int expectedLastStreamId, bool ignoreNonGoAwayFrames) + { + _pair.Application.Output.Complete(); + + return WaitForConnectionStopAsync(expectedLastStreamId, ignoreNonGoAwayFrames); + } + + private Task WaitForConnectionStopAsync(int expectedLastStreamId, bool ignoreNonGoAwayFrames) + { + return WaitForConnectionErrorAsync(ignoreNonGoAwayFrames, expectedLastStreamId, Http2ErrorCode.NO_ERROR, expectedErrorMessage: null); + } + + private async Task WaitForConnectionErrorAsync(bool ignoreNonGoAwayFrames, int expectedLastStreamId, Http2ErrorCode expectedErrorCode, string expectedErrorMessage) + where TException : Exception + { + var frame = await ReceiveFrameAsync(); + + if (ignoreNonGoAwayFrames) + { + while (frame.Type != Http2FrameType.GOAWAY) + { + frame = await ReceiveFrameAsync(); + } + } + + Assert.Equal(Http2FrameType.GOAWAY, frame.Type); + Assert.Equal(8, frame.Length); + Assert.Equal(0, frame.Flags); + Assert.Equal(0, frame.StreamId); + Assert.Equal(expectedLastStreamId, frame.GoAwayLastStreamId); + Assert.Equal(expectedErrorCode, frame.GoAwayErrorCode); + + if (expectedErrorMessage != null) + { + var message = Assert.Single(_logger.Messages, m => m.Exception is TException); + Assert.Contains(expectedErrorMessage, message.Exception.Message); + } + + await _connectionTask; + _pair.Application.Output.Complete(); + } + + private async Task WaitForStreamErrorAsync(bool ignoreNonRstStreamFrames, int expectedStreamId, Http2ErrorCode expectedErrorCode, string expectedErrorMessage) + { + var frame = await ReceiveFrameAsync(); + + if (ignoreNonRstStreamFrames) + { + while (frame.Type != Http2FrameType.RST_STREAM) + { + frame = await ReceiveFrameAsync(); + } + } + + Assert.Equal(Http2FrameType.RST_STREAM, frame.Type); + Assert.Equal(4, frame.Length); + Assert.Equal(0, frame.Flags); + Assert.Equal(expectedStreamId, frame.StreamId); + Assert.Equal(expectedErrorCode, frame.RstStreamErrorCode); + + if (expectedErrorMessage != null) + { + var message = Assert.Single(_logger.Messages, m => m.Exception is Http2StreamErrorException); + Assert.Contains(expectedErrorMessage, message.Exception.Message); + } + } + + private void VerifyDecodedRequestHeaders(IEnumerable> expectedHeaders) + { + foreach (var header in expectedHeaders) + { + Assert.True(_receivedHeaders.TryGetValue(header.Key, out var value), header.Key); + Assert.Equal(header.Value, value, ignoreCase: true); + } + } + + public static TheoryData UpperCaseHeaderNameData + { + get + { + // We can't use HPackEncoder here because it will convert header names to lowercase + var headerName = "abcdefghijklmnopqrstuvwxyz"; + + var headerBlockStart = new byte[] + { + 0x82, // Indexed Header Field - :method: GET + 0x84, // Indexed Header Field - :path: / + 0x86, // Indexed Header Field - :scheme: http + 0x00, // Literal Header Field without Indexing - New Name + (byte)headerName.Length, // Header name length + }; + + var headerBlockEnd = new byte[] + { + 0x01, // Header value length + 0x30 // "0" + }; + + var data = new TheoryData(); + + for (var i = 0; i < headerName.Length; i++) + { + var bytes = Encoding.ASCII.GetBytes(headerName); + bytes[i] &= 0xdf; + + var headerBlock = headerBlockStart.Concat(bytes).Concat(headerBlockEnd).ToArray(); + data.Add(headerBlock); + } + + return data; + } + } + + public static TheoryData>> DuplicatePseudoHeaderFieldData + { + get + { + var data = new TheoryData>>(); + var requestHeaders = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":authority", "127.0.0.1"), + new KeyValuePair(":scheme", "http"), + }; + + foreach (var headerField in requestHeaders) + { + var headers = requestHeaders.Concat(new[] { new KeyValuePair(headerField.Key, headerField.Value) }); + data.Add(headers); + } + + return data; + } + } + + public static TheoryData>> MissingPseudoHeaderFieldData + { + get + { + var data = new TheoryData>>(); + var requestHeaders = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + }; + + foreach (var headerField in requestHeaders) + { + var headers = requestHeaders.Except(new[] { headerField }); + data.Add(headers); + } + + return data; + } + } + + public static TheoryData>> ConnectMissingPseudoHeaderFieldData + { + get + { + var data = new TheoryData>>(); + var methodHeader = new[] { new KeyValuePair(":method", "CONNECT") }; + var requestHeaders = new[] + { + new KeyValuePair(":path", "/"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair(":authority", "127.0.0.1"), + }; + + foreach (var headerField in requestHeaders) + { + var headers = methodHeader.Concat(requestHeaders.Except(new[] { headerField })); + data.Add(headers); + } + + return data; + } + } + + public static TheoryData>> PseudoHeaderFieldAfterRegularHeadersData + { + get + { + var data = new TheoryData>>(); + var requestHeaders = new[] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":authority", "127.0.0.1"), + new KeyValuePair(":scheme", "http"), + new KeyValuePair("content-length", "0") + }; + + foreach (var headerField in requestHeaders.Where(h => h.Key.StartsWith(":"))) + { + var headers = requestHeaders.Except(new[] { headerField }).Concat(new[] { headerField }); + data.Add(headers); + } + + return data; + } + } + + public static TheoryData IllegalTrailerData + { + get + { + // We can't use HPackEncoder here because it will convert header names to lowercase + var data = new TheoryData(); + + // Indexed Header Field - :method: GET + data.Add(new byte[] { 0x82 }, CoreStrings.Http2ErrorTrailersContainPseudoHeaderField); + + // Indexed Header Field - :path: / + data.Add(new byte[] { 0x84 }, CoreStrings.Http2ErrorTrailersContainPseudoHeaderField); + + // Indexed Header Field - :scheme: http + data.Add(new byte[] { 0x86 }, CoreStrings.Http2ErrorTrailersContainPseudoHeaderField); + + // Literal Header Field without Indexing - Indexed Name - :authority: 127.0.0.1 + data.Add(new byte[] { 0x01, 0x09 }.Concat(Encoding.ASCII.GetBytes("127.0.0.1")).ToArray(), CoreStrings.Http2ErrorTrailersContainPseudoHeaderField); + + // Literal Header Field without Indexing - New Name - contains-Uppercase: 0 + data.Add(new byte[] { 0x00, 0x12 } + .Concat(Encoding.ASCII.GetBytes("contains-Uppercase")) + .Concat(new byte[] { 0x01, (byte)'0' }) + .ToArray(), CoreStrings.Http2ErrorTrailerNameUppercase); + + return data; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HttpConnectionManagerTests.cs b/src/Servers/Kestrel/Core/test/HttpConnectionManagerTests.cs new file mode 100644 index 0000000000..015aa75881 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HttpConnectionManagerTests.cs @@ -0,0 +1,60 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Runtime.CompilerServices; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HttpConnectionManagerTests + { + [Fact] + public void UnrootedConnectionsGetRemovedFromHeartbeat() + { + var connectionId = "0"; + var trace = new Mock(); + var httpConnectionManager = new HttpConnectionManager(trace.Object, ResourceCounter.Unlimited); + + // Create HttpConnection in inner scope so it doesn't get rooted by the current frame. + UnrootedConnectionsGetRemovedFromHeartbeatInnerScope(connectionId, httpConnectionManager, trace); + + GC.Collect(); + GC.WaitForPendingFinalizers(); + + var connectionCount = 0; + httpConnectionManager.Walk(_ => connectionCount++); + + Assert.Equal(0, connectionCount); + trace.Verify(t => t.ApplicationNeverCompleted(connectionId), Times.Once()); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void UnrootedConnectionsGetRemovedFromHeartbeatInnerScope( + string connectionId, + HttpConnectionManager httpConnectionManager, + Mock trace) + { + var httpConnection = new HttpConnection(new HttpConnectionContext + { + ServiceContext = new TestServiceContext(), + ConnectionId = connectionId + }); + + httpConnectionManager.AddConnection(0, httpConnection); + + var connectionCount = 0; + httpConnectionManager.Walk(_ => connectionCount++); + + Assert.Equal(1, connectionCount); + trace.Verify(t => t.ApplicationNeverCompleted(connectionId), Times.Never()); + + // Ensure httpConnection doesn't get GC'd before this point. + GC.KeepAlive(httpConnection); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HttpConnectionTests.cs b/src/Servers/Kestrel/Core/test/HttpConnectionTests.cs new file mode 100644 index 0000000000..c8dbf28ba0 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HttpConnectionTests.cs @@ -0,0 +1,594 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HttpConnectionTests : IDisposable + { + private readonly MemoryPool _memoryPool; + private readonly HttpConnectionContext _httpConnectionContext; + private readonly HttpConnection _httpConnection; + + public HttpConnectionTests() + { + _memoryPool = KestrelMemoryPool.Create(); + var options = new PipeOptions(_memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + + var connectionFeatures = new FeatureCollection(); + connectionFeatures.Set(Mock.Of()); + connectionFeatures.Set(Mock.Of()); + + _httpConnectionContext = new HttpConnectionContext + { + ConnectionId = "0123456789", + ConnectionContext = Mock.Of(), + ConnectionAdapters = new List(), + ConnectionFeatures = connectionFeatures, + MemoryPool = _memoryPool, + HttpConnectionId = long.MinValue, + Application = pair.Application, + Transport = pair.Transport, + ServiceContext = new TestServiceContext + { + SystemClock = new SystemClock() + } + }; + + _httpConnection = new HttpConnection(_httpConnectionContext); + } + + public void Dispose() + { + _memoryPool.Dispose(); + } + + [Fact] + public void DoesNotTimeOutWhenDebuggerIsAttached() + { + var mockDebugger = new Mock(); + mockDebugger.SetupGet(g => g.IsAttached).Returns(true); + _httpConnection.Debugger = mockDebugger.Object; + _httpConnection.Initialize(_httpConnectionContext.Transport, _httpConnectionContext.Application); + + var now = DateTimeOffset.Now; + _httpConnection.Tick(now); + _httpConnection.SetTimeout(1, TimeoutAction.SendTimeoutResponse); + _httpConnection.Tick(now.AddTicks(2).Add(Heartbeat.Interval)); + + Assert.False(_httpConnection.RequestTimedOut); + } + + [Fact] + public void DoesNotTimeOutWhenRequestBodyDoesNotSatisfyMinimumDataRateButDebuggerIsAttached() + { + var mockDebugger = new Mock(); + mockDebugger.SetupGet(g => g.IsAttached).Returns(true); + _httpConnection.Debugger = mockDebugger.Object; + var bytesPerSecond = 100; + var mockLogger = new Mock(); + mockLogger.Setup(l => l.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), It.IsAny())).Throws(new InvalidOperationException("Should not log")); + + TickBodyWithMinimumDataRate(mockLogger.Object, bytesPerSecond); + + Assert.False(_httpConnection.RequestTimedOut); + } + + [Fact] + public void TimesOutWhenRequestBodyDoesNotSatisfyMinimumDataRate() + { + var bytesPerSecond = 100; + var mockLogger = new Mock(); + TickBodyWithMinimumDataRate(mockLogger.Object, bytesPerSecond); + + // Timed out + Assert.True(_httpConnection.RequestTimedOut); + mockLogger.Verify(logger => + logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), bytesPerSecond), Times.Once); + } + + private void TickBodyWithMinimumDataRate(IKestrelTrace logger, int bytesPerSecond) + { + var gracePeriod = TimeSpan.FromSeconds(5); + + _httpConnectionContext.ServiceContext.ServerOptions.Limits.MinRequestBodyDataRate = + new MinDataRate(bytesPerSecond: bytesPerSecond, gracePeriod: gracePeriod); + + _httpConnectionContext.ServiceContext.Log = logger; + + _httpConnection.Initialize(_httpConnectionContext.Transport, _httpConnectionContext.Application); + _httpConnection.Http1Connection.Reset(); + + // Initialize timestamp + var now = DateTimeOffset.UtcNow; + _httpConnection.Tick(now); + + _httpConnection.StartTimingReads(); + + // Tick after grace period w/ low data rate + now += gracePeriod + TimeSpan.FromSeconds(1); + _httpConnection.BytesRead(1); + _httpConnection.Tick(now); + } + + [Fact] + public void RequestBodyMinimumDataRateNotEnforcedDuringGracePeriod() + { + var bytesPerSecond = 100; + var gracePeriod = TimeSpan.FromSeconds(2); + + _httpConnectionContext.ServiceContext.ServerOptions.Limits.MinRequestBodyDataRate = + new MinDataRate(bytesPerSecond: bytesPerSecond, gracePeriod: gracePeriod); + + var mockLogger = new Mock(); + _httpConnectionContext.ServiceContext.Log = mockLogger.Object; + + _httpConnection.Initialize(_httpConnectionContext.Transport, _httpConnectionContext.Application); + _httpConnection.Http1Connection.Reset(); + + // Initialize timestamp + var now = DateTimeOffset.UtcNow; + _httpConnection.Tick(now); + + _httpConnection.StartTimingReads(); + + // Tick during grace period w/ low data rate + now += TimeSpan.FromSeconds(1); + _httpConnection.BytesRead(10); + _httpConnection.Tick(now); + + // Not timed out + Assert.False(_httpConnection.RequestTimedOut); + mockLogger.Verify(logger => + logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), bytesPerSecond), Times.Never); + + // Tick after grace period w/ low data rate + now += TimeSpan.FromSeconds(2); + _httpConnection.BytesRead(10); + _httpConnection.Tick(now); + + // Timed out + Assert.True(_httpConnection.RequestTimedOut); + mockLogger.Verify(logger => + logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), bytesPerSecond), Times.Once); + } + + [Fact] + public void RequestBodyDataRateIsAveragedOverTimeSpentReadingRequestBody() + { + var bytesPerSecond = 100; + var gracePeriod = TimeSpan.FromSeconds(2); + + _httpConnectionContext.ServiceContext.ServerOptions.Limits.MinRequestBodyDataRate = + new MinDataRate(bytesPerSecond: bytesPerSecond, gracePeriod: gracePeriod); + + var mockLogger = new Mock(); + _httpConnectionContext.ServiceContext.Log = mockLogger.Object; + + _httpConnection.Initialize(_httpConnectionContext.Transport, _httpConnectionContext.Application); + _httpConnection.Http1Connection.Reset(); + + // Initialize timestamp + var now = DateTimeOffset.UtcNow; + _httpConnection.Tick(now); + + _httpConnection.StartTimingReads(); + + // Set base data rate to 200 bytes/second + now += gracePeriod; + _httpConnection.BytesRead(400); + _httpConnection.Tick(now); + + // Data rate: 200 bytes/second + now += TimeSpan.FromSeconds(1); + _httpConnection.BytesRead(200); + _httpConnection.Tick(now); + + // Not timed out + Assert.False(_httpConnection.RequestTimedOut); + mockLogger.Verify(logger => + logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), bytesPerSecond), Times.Never); + + // Data rate: 150 bytes/second + now += TimeSpan.FromSeconds(1); + _httpConnection.BytesRead(0); + _httpConnection.Tick(now); + + // Not timed out + Assert.False(_httpConnection.RequestTimedOut); + mockLogger.Verify(logger => + logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), bytesPerSecond), Times.Never); + + // Data rate: 120 bytes/second + now += TimeSpan.FromSeconds(1); + _httpConnection.BytesRead(0); + _httpConnection.Tick(now); + + // Not timed out + Assert.False(_httpConnection.RequestTimedOut); + mockLogger.Verify(logger => + logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), bytesPerSecond), Times.Never); + + // Data rate: 100 bytes/second + now += TimeSpan.FromSeconds(1); + _httpConnection.BytesRead(0); + _httpConnection.Tick(now); + + // Not timed out + Assert.False(_httpConnection.RequestTimedOut); + mockLogger.Verify(logger => + logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), bytesPerSecond), Times.Never); + + // Data rate: ~85 bytes/second + now += TimeSpan.FromSeconds(1); + _httpConnection.BytesRead(0); + _httpConnection.Tick(now); + + // Timed out + Assert.True(_httpConnection.RequestTimedOut); + mockLogger.Verify(logger => + logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), bytesPerSecond), Times.Once); + } + + [Fact] + public void RequestBodyDataRateNotComputedOnPausedTime() + { + var systemClock = new MockSystemClock(); + + _httpConnectionContext.ServiceContext.ServerOptions.Limits.MinRequestBodyDataRate = + new MinDataRate(bytesPerSecond: 100, gracePeriod: TimeSpan.FromSeconds(2)); + _httpConnectionContext.ServiceContext.SystemClock = systemClock; + + var mockLogger = new Mock(); + _httpConnectionContext.ServiceContext.Log = mockLogger.Object; + + _httpConnection.Initialize(_httpConnectionContext.Transport, _httpConnectionContext.Application); + _httpConnection.Http1Connection.Reset(); + + // Initialize timestamp + _httpConnection.Tick(systemClock.UtcNow); + + _httpConnection.StartTimingReads(); + + // Tick at 3s, expected counted time is 3s, expected data rate is 200 bytes/second + systemClock.UtcNow += TimeSpan.FromSeconds(3); + _httpConnection.BytesRead(600); + _httpConnection.Tick(systemClock.UtcNow); + + // Pause at 3.5s + systemClock.UtcNow += TimeSpan.FromSeconds(0.5); + _httpConnection.PauseTimingReads(); + + // Tick at 4s, expected counted time is 4s (first tick after pause goes through), expected data rate is 150 bytes/second + systemClock.UtcNow += TimeSpan.FromSeconds(0.5); + _httpConnection.Tick(systemClock.UtcNow); + + // Tick at 6s, expected counted time is 4s, expected data rate is 150 bytes/second + systemClock.UtcNow += TimeSpan.FromSeconds(2); + _httpConnection.Tick(systemClock.UtcNow); + + // Not timed out + Assert.False(_httpConnection.RequestTimedOut); + mockLogger.Verify( + logger => logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), It.IsAny()), + Times.Never); + + // Resume at 6.5s + systemClock.UtcNow += TimeSpan.FromSeconds(0.5); + _httpConnection.ResumeTimingReads(); + + // Tick at 9s, expected counted time is 6s, expected data rate is 100 bytes/second + systemClock.UtcNow += TimeSpan.FromSeconds(1.5); + _httpConnection.Tick(systemClock.UtcNow); + + // Not timed out + Assert.False(_httpConnection.RequestTimedOut); + mockLogger.Verify( + logger => logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), It.IsAny()), + Times.Never); + + // Tick at 10s, expected counted time is 7s, expected data rate drops below 100 bytes/second + systemClock.UtcNow += TimeSpan.FromSeconds(1); + _httpConnection.Tick(systemClock.UtcNow); + + // Timed out + Assert.True(_httpConnection.RequestTimedOut); + mockLogger.Verify( + logger => logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), It.IsAny()), + Times.Once); + } + + [Fact] + public void ReadTimingNotPausedWhenResumeCalledBeforeNextTick() + { + var systemClock = new MockSystemClock(); + + _httpConnectionContext.ServiceContext.ServerOptions.Limits.MinRequestBodyDataRate = + new MinDataRate(bytesPerSecond: 100, gracePeriod: TimeSpan.FromSeconds(2)); + _httpConnectionContext.ServiceContext.SystemClock = systemClock; + + var mockLogger = new Mock(); + _httpConnectionContext.ServiceContext.Log = mockLogger.Object; + + _httpConnection.Initialize(_httpConnectionContext.Transport, _httpConnectionContext.Application); + _httpConnection.Http1Connection.Reset(); + + // Initialize timestamp + _httpConnection.Tick(systemClock.UtcNow); + + _httpConnection.StartTimingReads(); + + // Tick at 2s, expected counted time is 2s, expected data rate is 100 bytes/second + systemClock.UtcNow += TimeSpan.FromSeconds(2); + _httpConnection.BytesRead(200); + _httpConnection.Tick(systemClock.UtcNow); + + // Not timed out + Assert.False(_httpConnection.RequestTimedOut); + mockLogger.Verify( + logger => logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), It.IsAny()), + Times.Never); + + // Pause at 2.25s + systemClock.UtcNow += TimeSpan.FromSeconds(0.25); + _httpConnection.PauseTimingReads(); + + // Resume at 2.5s + systemClock.UtcNow += TimeSpan.FromSeconds(0.25); + _httpConnection.ResumeTimingReads(); + + // Tick at 3s, expected counted time is 3s, expected data rate is 100 bytes/second + systemClock.UtcNow += TimeSpan.FromSeconds(0.5); + _httpConnection.BytesRead(100); + _httpConnection.Tick(systemClock.UtcNow); + + // Not timed out + Assert.False(_httpConnection.RequestTimedOut); + mockLogger.Verify( + logger => logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), It.IsAny()), + Times.Never); + + // Tick at 4s, expected counted time is 4s, expected data rate drops below 100 bytes/second + systemClock.UtcNow += TimeSpan.FromSeconds(1); + _httpConnection.Tick(systemClock.UtcNow); + + // Timed out + Assert.True(_httpConnection.RequestTimedOut); + mockLogger.Verify( + logger => logger.RequestBodyMininumDataRateNotSatisfied(It.IsAny(), It.IsAny(), It.IsAny()), + Times.Once); + } + + [Fact] + public void ReadTimingNotEnforcedWhenTimeoutIsSet() + { + var systemClock = new MockSystemClock(); + var timeout = TimeSpan.FromSeconds(5); + + _httpConnectionContext.ServiceContext.ServerOptions.Limits.MinRequestBodyDataRate = + new MinDataRate(bytesPerSecond: 100, gracePeriod: TimeSpan.FromSeconds(2)); + _httpConnectionContext.ServiceContext.SystemClock = systemClock; + + var mockLogger = new Mock(); + _httpConnectionContext.ServiceContext.Log = mockLogger.Object; + + _httpConnection.Initialize(_httpConnectionContext.Transport, _httpConnectionContext.Application); + _httpConnection.Http1Connection.Reset(); + + var startTime = systemClock.UtcNow; + + // Initialize timestamp + _httpConnection.Tick(startTime); + + _httpConnection.StartTimingReads(); + + _httpConnection.SetTimeout(timeout.Ticks, TimeoutAction.StopProcessingNextRequest); + + // Tick beyond grace period with low data rate + systemClock.UtcNow += TimeSpan.FromSeconds(3); + _httpConnection.BytesRead(1); + _httpConnection.Tick(systemClock.UtcNow); + + // Not timed out + Assert.False(_httpConnection.RequestTimedOut); + + // Tick just past timeout period, adjusted by Heartbeat.Interval + systemClock.UtcNow = startTime + timeout + Heartbeat.Interval + TimeSpan.FromTicks(1); + _httpConnection.Tick(systemClock.UtcNow); + + // Timed out + Assert.True(_httpConnection.RequestTimedOut); + } + + [Fact] + public async Task WriteTimingAbortsConnectionWhenWriteDoesNotCompleteWithMinimumDataRate() + { + var systemClock = new MockSystemClock(); + var aborted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + _httpConnectionContext.ServiceContext.ServerOptions.Limits.MinResponseDataRate = + new MinDataRate(bytesPerSecond: 100, gracePeriod: TimeSpan.FromSeconds(2)); + _httpConnectionContext.ServiceContext.SystemClock = systemClock; + + var mockLogger = new Mock(); + _httpConnectionContext.ServiceContext.Log = mockLogger.Object; + + _httpConnection.Initialize(_httpConnectionContext.Transport, _httpConnectionContext.Application); + _httpConnection.Http1Connection.Reset(); + _httpConnection.Http1Connection.RequestAborted.Register(() => + { + aborted.SetResult(null); + }); + + // Initialize timestamp + _httpConnection.Tick(systemClock.UtcNow); + + // Should complete within 4 seconds, but the timeout is adjusted by adding Heartbeat.Interval + _httpConnection.StartTimingWrite(400); + + // Tick just past 4s plus Heartbeat.Interval + systemClock.UtcNow += TimeSpan.FromSeconds(4) + Heartbeat.Interval + TimeSpan.FromTicks(1); + _httpConnection.Tick(systemClock.UtcNow); + + Assert.True(_httpConnection.RequestTimedOut); + await aborted.Task.DefaultTimeout(); + } + + [Fact] + public async Task WriteTimingAbortsConnectionWhenSmallWriteDoesNotCompleteWithinGracePeriod() + { + var systemClock = new MockSystemClock(); + var minResponseDataRate = new MinDataRate(bytesPerSecond: 100, gracePeriod: TimeSpan.FromSeconds(5)); + var aborted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + _httpConnectionContext.ServiceContext.ServerOptions.Limits.MinResponseDataRate = minResponseDataRate; + _httpConnectionContext.ServiceContext.SystemClock = systemClock; + + var mockLogger = new Mock(); + _httpConnectionContext.ServiceContext.Log = mockLogger.Object; + + _httpConnection.Initialize(_httpConnectionContext.Transport, _httpConnectionContext.Application); + _httpConnection.Http1Connection.Reset(); + _httpConnection.Http1Connection.RequestAborted.Register(() => + { + aborted.SetResult(null); + }); + + // Initialize timestamp + var startTime = systemClock.UtcNow; + _httpConnection.Tick(startTime); + + // Should complete within 1 second, but the timeout is adjusted by adding Heartbeat.Interval + _httpConnection.StartTimingWrite(100); + + // Tick just past 1s plus Heartbeat.Interval + systemClock.UtcNow += TimeSpan.FromSeconds(1) + Heartbeat.Interval + TimeSpan.FromTicks(1); + _httpConnection.Tick(systemClock.UtcNow); + + // Still within grace period, not timed out + Assert.False(_httpConnection.RequestTimedOut); + + // Tick just past grace period (adjusted by Heartbeat.Interval) + systemClock.UtcNow = startTime + minResponseDataRate.GracePeriod + Heartbeat.Interval + TimeSpan.FromTicks(1); + _httpConnection.Tick(systemClock.UtcNow); + + Assert.True(_httpConnection.RequestTimedOut); + await aborted.Task.DefaultTimeout(); + } + + [Fact] + public async Task WriteTimingTimeoutPushedOnConcurrentWrite() + { + var systemClock = new MockSystemClock(); + var aborted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + _httpConnectionContext.ServiceContext.ServerOptions.Limits.MinResponseDataRate = + new MinDataRate(bytesPerSecond: 100, gracePeriod: TimeSpan.FromSeconds(2)); + _httpConnectionContext.ServiceContext.SystemClock = systemClock; + + var mockLogger = new Mock(); + _httpConnectionContext.ServiceContext.Log = mockLogger.Object; + + _httpConnection.Initialize(_httpConnectionContext.Transport, _httpConnectionContext.Application); + _httpConnection.Http1Connection.Reset(); + _httpConnection.Http1Connection.RequestAborted.Register(() => + { + aborted.SetResult(null); + }); + + // Initialize timestamp + _httpConnection.Tick(systemClock.UtcNow); + + // Should complete within 5 seconds, but the timeout is adjusted by adding Heartbeat.Interval + _httpConnection.StartTimingWrite(500); + + // Start a concurrent write after 3 seconds, which should complete within 3 seconds (adjusted by Heartbeat.Interval) + _httpConnection.StartTimingWrite(300); + + // Tick just past 5s plus Heartbeat.Interval, when the first write should have completed + systemClock.UtcNow += TimeSpan.FromSeconds(5) + Heartbeat.Interval + TimeSpan.FromTicks(1); + _httpConnection.Tick(systemClock.UtcNow); + + // Not timed out because the timeout was pushed by the second write + Assert.False(_httpConnection.RequestTimedOut); + + // Complete the first write, this should have no effect on the timeout + _httpConnection.StopTimingWrite(); + + // Tick just past +3s, when the second write should have completed + systemClock.UtcNow += TimeSpan.FromSeconds(3) + TimeSpan.FromTicks(1); + _httpConnection.Tick(systemClock.UtcNow); + + Assert.True(_httpConnection.RequestTimedOut); + await aborted.Task.DefaultTimeout(); + } + + [Fact] + public async Task WriteTimingAbortsConnectionWhenRepeadtedSmallWritesDoNotCompleteWithMinimumDataRate() + { + var systemClock = new MockSystemClock(); + var minResponseDataRate = new MinDataRate(bytesPerSecond: 100, gracePeriod: TimeSpan.FromSeconds(5)); + var numWrites = 5; + var writeSize = 100; + var aborted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + _httpConnectionContext.ServiceContext.ServerOptions.Limits.MinResponseDataRate = minResponseDataRate; + _httpConnectionContext.ServiceContext.SystemClock = systemClock; + + var mockLogger = new Mock(); + _httpConnectionContext.ServiceContext.Log = mockLogger.Object; + + _httpConnection.Initialize(_httpConnectionContext.Transport, _httpConnectionContext.Application); + _httpConnection.Http1Connection.Reset(); + _httpConnection.Http1Connection.RequestAborted.Register(() => + { + aborted.SetResult(null); + }); + + // Initialize timestamp + var startTime = systemClock.UtcNow; + _httpConnection.Tick(startTime); + + // 5 consecutive 100 byte writes. + for (var i = 0; i < numWrites - 1; i++) + { + _httpConnection.StartTimingWrite(writeSize); + _httpConnection.StopTimingWrite(); + } + + // Stall the last write. + _httpConnection.StartTimingWrite(writeSize); + + // Move the clock forward Heartbeat.Interval + MinDataRate.GracePeriod + 4 seconds. + // The grace period should only be added for the first write. The subsequent 4 100 byte writes should add 1 second each to the timeout given the 100 byte/s min rate. + systemClock.UtcNow += Heartbeat.Interval + minResponseDataRate.GracePeriod + TimeSpan.FromSeconds((numWrites - 1) * writeSize / minResponseDataRate.BytesPerSecond); + _httpConnection.Tick(systemClock.UtcNow); + + Assert.False(_httpConnection.RequestTimedOut); + + // On more tick forward triggers the timeout. + systemClock.UtcNow += TimeSpan.FromTicks(1); + _httpConnection.Tick(systemClock.UtcNow); + + Assert.True(_httpConnection.RequestTimedOut); + await aborted.Task.TimeoutAfter(TimeSpan.FromSeconds(10)); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HttpHeadersTests.cs b/src/Servers/Kestrel/Core/test/HttpHeadersTests.cs new file mode 100644 index 0000000000..2aef432898 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HttpHeadersTests.cs @@ -0,0 +1,277 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.Extensions.Primitives; +using Microsoft.Net.Http.Headers; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HttpHeadersTests + { + [Theory] + [InlineData("", ConnectionOptions.None)] + [InlineData(",", ConnectionOptions.None)] + [InlineData(" ,", ConnectionOptions.None)] + [InlineData(" , ", ConnectionOptions.None)] + [InlineData(",,", ConnectionOptions.None)] + [InlineData(" ,,", ConnectionOptions.None)] + [InlineData(",, ", ConnectionOptions.None)] + [InlineData(" , ,", ConnectionOptions.None)] + [InlineData(" , , ", ConnectionOptions.None)] + [InlineData("keep-alive", ConnectionOptions.KeepAlive)] + [InlineData("keep-alive, upgrade", ConnectionOptions.KeepAlive | ConnectionOptions.Upgrade)] + [InlineData("keep-alive,upgrade", ConnectionOptions.KeepAlive | ConnectionOptions.Upgrade)] + [InlineData("upgrade, keep-alive", ConnectionOptions.KeepAlive | ConnectionOptions.Upgrade)] + [InlineData("upgrade,keep-alive", ConnectionOptions.KeepAlive | ConnectionOptions.Upgrade)] + [InlineData("upgrade,,keep-alive", ConnectionOptions.KeepAlive | ConnectionOptions.Upgrade)] + [InlineData("keep-alive,", ConnectionOptions.KeepAlive)] + [InlineData("keep-alive,,", ConnectionOptions.KeepAlive)] + [InlineData("keep-alive, ", ConnectionOptions.KeepAlive)] + [InlineData("keep-alive, ,", ConnectionOptions.KeepAlive)] + [InlineData("keep-alive, , ", ConnectionOptions.KeepAlive)] + [InlineData("keep-alive ,", ConnectionOptions.KeepAlive)] + [InlineData(",keep-alive", ConnectionOptions.KeepAlive)] + [InlineData(", keep-alive", ConnectionOptions.KeepAlive)] + [InlineData(",,keep-alive", ConnectionOptions.KeepAlive)] + [InlineData(", ,keep-alive", ConnectionOptions.KeepAlive)] + [InlineData(",, keep-alive", ConnectionOptions.KeepAlive)] + [InlineData(", , keep-alive", ConnectionOptions.KeepAlive)] + [InlineData("upgrade,", ConnectionOptions.Upgrade)] + [InlineData("upgrade,,", ConnectionOptions.Upgrade)] + [InlineData("upgrade, ", ConnectionOptions.Upgrade)] + [InlineData("upgrade, ,", ConnectionOptions.Upgrade)] + [InlineData("upgrade, , ", ConnectionOptions.Upgrade)] + [InlineData("upgrade ,", ConnectionOptions.Upgrade)] + [InlineData(",upgrade", ConnectionOptions.Upgrade)] + [InlineData(", upgrade", ConnectionOptions.Upgrade)] + [InlineData(",,upgrade", ConnectionOptions.Upgrade)] + [InlineData(", ,upgrade", ConnectionOptions.Upgrade)] + [InlineData(",, upgrade", ConnectionOptions.Upgrade)] + [InlineData(", , upgrade", ConnectionOptions.Upgrade)] + [InlineData("close,", ConnectionOptions.Close)] + [InlineData("close,,", ConnectionOptions.Close)] + [InlineData("close, ", ConnectionOptions.Close)] + [InlineData("close, ,", ConnectionOptions.Close)] + [InlineData("close, , ", ConnectionOptions.Close)] + [InlineData("close ,", ConnectionOptions.Close)] + [InlineData(",close", ConnectionOptions.Close)] + [InlineData(", close", ConnectionOptions.Close)] + [InlineData(",,close", ConnectionOptions.Close)] + [InlineData(", ,close", ConnectionOptions.Close)] + [InlineData(",, close", ConnectionOptions.Close)] + [InlineData(", , close", ConnectionOptions.Close)] + [InlineData("kupgrade", ConnectionOptions.None)] + [InlineData("keupgrade", ConnectionOptions.None)] + [InlineData("ukeep-alive", ConnectionOptions.None)] + [InlineData("upkeep-alive", ConnectionOptions.None)] + [InlineData("k,upgrade", ConnectionOptions.Upgrade)] + [InlineData("u,keep-alive", ConnectionOptions.KeepAlive)] + [InlineData("ke,upgrade", ConnectionOptions.Upgrade)] + [InlineData("up,keep-alive", ConnectionOptions.KeepAlive)] + [InlineData("close", ConnectionOptions.Close)] + [InlineData("upgrade,close", ConnectionOptions.Close | ConnectionOptions.Upgrade)] + [InlineData("close,upgrade", ConnectionOptions.Close | ConnectionOptions.Upgrade)] + [InlineData("keep-alive2", ConnectionOptions.None)] + [InlineData("keep-alive2 ", ConnectionOptions.None)] + [InlineData("keep-alive2 ,", ConnectionOptions.None)] + [InlineData("keep-alive2,", ConnectionOptions.None)] + [InlineData("upgrade2", ConnectionOptions.None)] + [InlineData("upgrade2 ", ConnectionOptions.None)] + [InlineData("upgrade2 ,", ConnectionOptions.None)] + [InlineData("upgrade2,", ConnectionOptions.None)] + [InlineData("close2", ConnectionOptions.None)] + [InlineData("close2 ", ConnectionOptions.None)] + [InlineData("close2 ,", ConnectionOptions.None)] + [InlineData("close2,", ConnectionOptions.None)] + [InlineData("keep-alivekeep-alive", ConnectionOptions.None)] + [InlineData("keep-aliveupgrade", ConnectionOptions.None)] + [InlineData("upgradeupgrade", ConnectionOptions.None)] + [InlineData("upgradekeep-alive", ConnectionOptions.None)] + [InlineData("closeclose", ConnectionOptions.None)] + [InlineData("closeupgrade", ConnectionOptions.None)] + [InlineData("upgradeclose", ConnectionOptions.None)] + [InlineData("keep-alive 2", ConnectionOptions.None)] + [InlineData("upgrade 2", ConnectionOptions.None)] + [InlineData("keep-alive 2, close", ConnectionOptions.Close)] + [InlineData("upgrade 2, close", ConnectionOptions.Close)] + [InlineData("close, keep-alive 2", ConnectionOptions.Close)] + [InlineData("close, upgrade 2", ConnectionOptions.Close)] + [InlineData("close 2, upgrade", ConnectionOptions.Upgrade)] + [InlineData("upgrade, close 2", ConnectionOptions.Upgrade)] + [InlineData("k2ep-alive", ConnectionOptions.None)] + [InlineData("ke2p-alive", ConnectionOptions.None)] + [InlineData("u2grade", ConnectionOptions.None)] + [InlineData("up2rade", ConnectionOptions.None)] + [InlineData("c2ose", ConnectionOptions.None)] + [InlineData("cl2se", ConnectionOptions.None)] + [InlineData("k2ep-alive,", ConnectionOptions.None)] + [InlineData("ke2p-alive,", ConnectionOptions.None)] + [InlineData("u2grade,", ConnectionOptions.None)] + [InlineData("up2rade,", ConnectionOptions.None)] + [InlineData("c2ose,", ConnectionOptions.None)] + [InlineData("cl2se,", ConnectionOptions.None)] + [InlineData("k2ep-alive ", ConnectionOptions.None)] + [InlineData("ke2p-alive ", ConnectionOptions.None)] + [InlineData("u2grade ", ConnectionOptions.None)] + [InlineData("up2rade ", ConnectionOptions.None)] + [InlineData("c2ose ", ConnectionOptions.None)] + [InlineData("cl2se ", ConnectionOptions.None)] + [InlineData("k2ep-alive ,", ConnectionOptions.None)] + [InlineData("ke2p-alive ,", ConnectionOptions.None)] + [InlineData("u2grade ,", ConnectionOptions.None)] + [InlineData("up2rade ,", ConnectionOptions.None)] + [InlineData("c2ose ,", ConnectionOptions.None)] + [InlineData("cl2se ,", ConnectionOptions.None)] + public void TestParseConnection(string connection, ConnectionOptions expectedConnectionOptions) + { + var connectionOptions = HttpHeaders.ParseConnection(connection); + Assert.Equal(expectedConnectionOptions, connectionOptions); + } + + [Theory] + [InlineData("keep-alive", "upgrade", ConnectionOptions.KeepAlive | ConnectionOptions.Upgrade)] + [InlineData("upgrade", "keep-alive", ConnectionOptions.KeepAlive | ConnectionOptions.Upgrade)] + [InlineData("keep-alive", "", ConnectionOptions.KeepAlive)] + [InlineData("", "keep-alive", ConnectionOptions.KeepAlive)] + [InlineData("upgrade", "", ConnectionOptions.Upgrade)] + [InlineData("", "upgrade", ConnectionOptions.Upgrade)] + [InlineData("keep-alive, upgrade", "", ConnectionOptions.KeepAlive | ConnectionOptions.Upgrade)] + [InlineData("upgrade, keep-alive", "", ConnectionOptions.KeepAlive | ConnectionOptions.Upgrade)] + [InlineData("", "keep-alive, upgrade", ConnectionOptions.KeepAlive | ConnectionOptions.Upgrade)] + [InlineData("", "upgrade, keep-alive", ConnectionOptions.KeepAlive | ConnectionOptions.Upgrade)] + [InlineData("", "", ConnectionOptions.None)] + [InlineData("close", "", ConnectionOptions.Close)] + [InlineData("", "close", ConnectionOptions.Close)] + [InlineData("close", "upgrade", ConnectionOptions.Close | ConnectionOptions.Upgrade)] + [InlineData("upgrade", "close", ConnectionOptions.Close | ConnectionOptions.Upgrade)] + public void TestParseConnectionMultipleValues(string value1, string value2, ConnectionOptions expectedConnectionOptions) + { + var connection = new StringValues(new[] { value1, value2 }); + var connectionOptions = HttpHeaders.ParseConnection(connection); + Assert.Equal(expectedConnectionOptions, connectionOptions); + } + + [Theory] + [InlineData("", TransferCoding.None)] + [InlineData(",,", TransferCoding.None)] + [InlineData(" ,,", TransferCoding.None)] + [InlineData(",, ", TransferCoding.None)] + [InlineData(" , ,", TransferCoding.None)] + [InlineData(" , , ", TransferCoding.None)] + [InlineData("chunked,", TransferCoding.Chunked)] + [InlineData("chunked,,", TransferCoding.Chunked)] + [InlineData("chunked, ", TransferCoding.Chunked)] + [InlineData("chunked, ,", TransferCoding.Chunked)] + [InlineData("chunked, , ", TransferCoding.Chunked)] + [InlineData("chunked ,", TransferCoding.Chunked)] + [InlineData(",chunked", TransferCoding.Chunked)] + [InlineData(", chunked", TransferCoding.Chunked)] + [InlineData(",,chunked", TransferCoding.Chunked)] + [InlineData(", ,chunked", TransferCoding.Chunked)] + [InlineData(",, chunked", TransferCoding.Chunked)] + [InlineData(", , chunked", TransferCoding.Chunked)] + [InlineData("chunked, gzip", TransferCoding.Other)] + [InlineData("chunked,compress", TransferCoding.Other)] + [InlineData("deflate, chunked", TransferCoding.Chunked)] + [InlineData("gzip,chunked", TransferCoding.Chunked)] + [InlineData("compress,,chunked", TransferCoding.Chunked)] + [InlineData("chunkedchunked", TransferCoding.Other)] + [InlineData("chunked2", TransferCoding.Other)] + [InlineData("chunked 2", TransferCoding.Other)] + [InlineData("2chunked", TransferCoding.Other)] + [InlineData("c2unked", TransferCoding.Other)] + [InlineData("ch2nked", TransferCoding.Other)] + [InlineData("chunked 2, gzip", TransferCoding.Other)] + [InlineData("chunked2, gzip", TransferCoding.Other)] + [InlineData("gzip, chunked 2", TransferCoding.Other)] + [InlineData("gzip, chunked2", TransferCoding.Other)] + public void TestParseTransferEncoding(string transferEncoding, TransferCoding expectedTransferEncodingOptions) + { + var transferEncodingOptions = HttpHeaders.GetFinalTransferCoding(transferEncoding); + Assert.Equal(expectedTransferEncodingOptions, transferEncodingOptions); + } + + [Theory] + [InlineData("chunked", "gzip", TransferCoding.Other)] + [InlineData("compress", "chunked", TransferCoding.Chunked)] + [InlineData("chunked", "", TransferCoding.Chunked)] + [InlineData("", "chunked", TransferCoding.Chunked)] + [InlineData("chunked, deflate", "", TransferCoding.Other)] + [InlineData("gzip, chunked", "", TransferCoding.Chunked)] + [InlineData("", "chunked, compress", TransferCoding.Other)] + [InlineData("", "compress, chunked", TransferCoding.Chunked)] + [InlineData("", "", TransferCoding.None)] + [InlineData("deflate", "", TransferCoding.Other)] + [InlineData("", "gzip", TransferCoding.Other)] + public void TestParseTransferEncodingMultipleValues(string value1, string value2, TransferCoding expectedTransferEncodingOptions) + { + var transferEncoding = new StringValues(new[] { value1, value2 }); + var transferEncodingOptions = HttpHeaders.GetFinalTransferCoding(transferEncoding); + Assert.Equal(expectedTransferEncodingOptions, transferEncodingOptions); + } + + [Fact] + public void ValidContentLengthsAccepted() + { + ValidContentLengthsAcceptedImpl(new HttpRequestHeaders()); + ValidContentLengthsAcceptedImpl(new HttpResponseHeaders()); + } + + private static void ValidContentLengthsAcceptedImpl(HttpHeaders httpHeaders) + { + IDictionary headers = httpHeaders; + + StringValues value; + + Assert.False(headers.TryGetValue("Content-Length", out value)); + Assert.Null(httpHeaders.ContentLength); + Assert.False(httpHeaders.ContentLength.HasValue); + + httpHeaders.ContentLength = 1; + Assert.True(headers.TryGetValue("Content-Length", out value)); + Assert.Equal("1", value[0]); + Assert.Equal(1, httpHeaders.ContentLength); + Assert.True(httpHeaders.ContentLength.HasValue); + + httpHeaders.ContentLength = long.MaxValue; + Assert.True(headers.TryGetValue("Content-Length", out value)); + Assert.Equal(HeaderUtilities.FormatNonNegativeInt64(long.MaxValue), value[0]); + Assert.Equal(long.MaxValue, httpHeaders.ContentLength); + Assert.True(httpHeaders.ContentLength.HasValue); + + httpHeaders.ContentLength = null; + Assert.False(headers.TryGetValue("Content-Length", out value)); + Assert.Null(httpHeaders.ContentLength); + Assert.False(httpHeaders.ContentLength.HasValue); + } + + [Fact] + public void InvalidContentLengthsRejected() + { + InvalidContentLengthsRejectedImpl(new HttpRequestHeaders()); + InvalidContentLengthsRejectedImpl(new HttpResponseHeaders()); + } + + private static void InvalidContentLengthsRejectedImpl(HttpHeaders httpHeaders) + { + IDictionary headers = httpHeaders; + + StringValues value; + + Assert.False(headers.TryGetValue("Content-Length", out value)); + Assert.Null(httpHeaders.ContentLength); + Assert.False(httpHeaders.ContentLength.HasValue); + + Assert.Throws(() => httpHeaders.ContentLength = -1); + Assert.Throws(() => httpHeaders.ContentLength = long.MinValue); + + Assert.False(headers.TryGetValue("Content-Length", out value)); + Assert.Null(httpHeaders.ContentLength); + Assert.False(httpHeaders.ContentLength.HasValue); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HttpParserTests.cs b/src/Servers/Kestrel/Core/test/HttpParserTests.cs new file mode 100644 index 0000000000..f40c25669b --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HttpParserTests.cs @@ -0,0 +1,520 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HttpParserTests + { + private static IKestrelTrace _nullTrace = Mock.Of(); + + [Theory] + [MemberData(nameof(RequestLineValidData))] + public void ParsesRequestLine( + string requestLine, + string expectedMethod, + string expectedRawTarget, + string expectedRawPath, + // This warns that theory methods should use all of their parameters, + // but this method is using a shared data collection with Http1ConnectionTests.TakeStartLineSetsHttpProtocolProperties and others. +#pragma warning disable xUnit1026 + string expectedDecodedPath, + string expectedQueryString, +#pragma warning restore xUnit1026 + string expectedVersion) + { + var parser = CreateParser(_nullTrace); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); + var requestHandler = new RequestHandler(); + + Assert.True(parser.ParseRequestLine(requestHandler, buffer, out var consumed, out var examined)); + + Assert.Equal(requestHandler.Method, expectedMethod); + Assert.Equal(requestHandler.Version, expectedVersion); + Assert.Equal(requestHandler.RawTarget, expectedRawTarget); + Assert.Equal(requestHandler.RawPath, expectedRawPath); + Assert.Equal(requestHandler.Version, expectedVersion); + Assert.True(buffer.Slice(consumed).IsEmpty); + Assert.True(buffer.Slice(examined).IsEmpty); + } + + [Theory] + [MemberData(nameof(RequestLineIncompleteData))] + public void ParseRequestLineReturnsFalseWhenGivenIncompleteRequestLines(string requestLine) + { + var parser = CreateParser(_nullTrace); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); + var requestHandler = new RequestHandler(); + + Assert.False(parser.ParseRequestLine(requestHandler, buffer, out var consumed, out var examined)); + } + + [Theory] + [MemberData(nameof(RequestLineIncompleteData))] + public void ParseRequestLineDoesNotConsumeIncompleteRequestLine(string requestLine) + { + var parser = CreateParser(_nullTrace); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); + var requestHandler = new RequestHandler(); + + Assert.False(parser.ParseRequestLine(requestHandler, buffer, out var consumed, out var examined)); + + Assert.Equal(buffer.Start, consumed); + Assert.True(buffer.Slice(examined).IsEmpty); + } + + [Theory] + [MemberData(nameof(RequestLineInvalidData))] + public void ParseRequestLineThrowsOnInvalidRequestLine(string requestLine) + { + var mockTrace = new Mock(); + mockTrace + .Setup(trace => trace.IsEnabled(LogLevel.Information)) + .Returns(true); + + var parser = CreateParser(mockTrace.Object); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); + var requestHandler = new RequestHandler(); + + var exception = Assert.Throws(() => + parser.ParseRequestLine(requestHandler, buffer, out var consumed, out var examined)); + + Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestLine_Detail(requestLine.EscapeNonPrintable()), exception.Message); + Assert.Equal(StatusCodes.Status400BadRequest, (exception as BadHttpRequestException).StatusCode); + } + + [Theory] + [MemberData(nameof(MethodWithNonTokenCharData))] + public void ParseRequestLineThrowsOnNonTokenCharsInCustomMethod(string method) + { + var requestLine = $"{method} / HTTP/1.1\r\n"; + + var mockTrace = new Mock(); + mockTrace + .Setup(trace => trace.IsEnabled(LogLevel.Information)) + .Returns(true); + + var parser = CreateParser(mockTrace.Object); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); + var requestHandler = new RequestHandler(); + + var exception = Assert.Throws(() => + parser.ParseRequestLine(requestHandler, buffer, out var consumed, out var examined)); + + Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestLine_Detail(method.EscapeNonPrintable() + @" / HTTP/1.1\x0D\x0A"), exception.Message); + Assert.Equal(StatusCodes.Status400BadRequest, (exception as BadHttpRequestException).StatusCode); + } + + [Theory] + [MemberData(nameof(UnrecognizedHttpVersionData))] + public void ParseRequestLineThrowsOnUnrecognizedHttpVersion(string httpVersion) + { + var requestLine = $"GET / {httpVersion}\r\n"; + + var mockTrace = new Mock(); + mockTrace + .Setup(trace => trace.IsEnabled(LogLevel.Information)) + .Returns(true); + + var parser = CreateParser(mockTrace.Object); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(requestLine)); + var requestHandler = new RequestHandler(); + + var exception = Assert.Throws(() => + parser.ParseRequestLine(requestHandler, buffer, out var consumed, out var examined)); + + Assert.Equal(CoreStrings.FormatBadRequest_UnrecognizedHTTPVersion(httpVersion), exception.Message); + Assert.Equal(StatusCodes.Status505HttpVersionNotsupported, (exception as BadHttpRequestException).StatusCode); + } + + [Theory] + [InlineData("\r")] + [InlineData("H")] + [InlineData("He")] + [InlineData("Hea")] + [InlineData("Head")] + [InlineData("Heade")] + [InlineData("Header")] + [InlineData("Header:")] + [InlineData("Header: ")] + [InlineData("Header: v")] + [InlineData("Header: va")] + [InlineData("Header: val")] + [InlineData("Header: valu")] + [InlineData("Header: value")] + [InlineData("Header: value\r")] + [InlineData("Header: value\r\n")] + [InlineData("Header: value\r\n\r")] + [InlineData("Header-1: value1\r\nH")] + [InlineData("Header-1: value1\r\nHe")] + [InlineData("Header-1: value1\r\nHea")] + [InlineData("Header-1: value1\r\nHead")] + [InlineData("Header-1: value1\r\nHeade")] + [InlineData("Header-1: value1\r\nHeader")] + [InlineData("Header-1: value1\r\nHeader-")] + [InlineData("Header-1: value1\r\nHeader-2")] + [InlineData("Header-1: value1\r\nHeader-2:")] + [InlineData("Header-1: value1\r\nHeader-2: ")] + [InlineData("Header-1: value1\r\nHeader-2: v")] + [InlineData("Header-1: value1\r\nHeader-2: va")] + [InlineData("Header-1: value1\r\nHeader-2: val")] + [InlineData("Header-1: value1\r\nHeader-2: valu")] + [InlineData("Header-1: value1\r\nHeader-2: value")] + [InlineData("Header-1: value1\r\nHeader-2: value2")] + [InlineData("Header-1: value1\r\nHeader-2: value2\r")] + [InlineData("Header-1: value1\r\nHeader-2: value2\r\n")] + [InlineData("Header-1: value1\r\nHeader-2: value2\r\n\r")] + public void ParseHeadersReturnsFalseWhenGivenIncompleteHeaders(string rawHeaders) + { + var parser = CreateParser(_nullTrace); + + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(rawHeaders)); + var requestHandler = new RequestHandler(); + Assert.False(parser.ParseHeaders(requestHandler, buffer, out var consumed, out var examined, out var consumedBytes)); + } + + [Theory] + [InlineData("\r")] + [InlineData("H")] + [InlineData("He")] + [InlineData("Hea")] + [InlineData("Head")] + [InlineData("Heade")] + [InlineData("Header")] + [InlineData("Header:")] + [InlineData("Header: ")] + [InlineData("Header: v")] + [InlineData("Header: va")] + [InlineData("Header: val")] + [InlineData("Header: valu")] + [InlineData("Header: value")] + [InlineData("Header: value\r")] + public void ParseHeadersDoesNotConsumeIncompleteHeader(string rawHeaders) + { + var parser = CreateParser(_nullTrace); + + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(rawHeaders)); + var requestHandler = new RequestHandler(); + parser.ParseHeaders(requestHandler, buffer, out var consumed, out var examined, out var consumedBytes); + + Assert.Equal(buffer.Length, buffer.Slice(consumed).Length); + Assert.True(buffer.Slice(examined).IsEmpty); + Assert.Equal(0, consumedBytes); + } + + [Fact] + public void ParseHeadersCanReadHeaderValueWithoutLeadingWhitespace() + { + VerifyHeader("Header", "value", "value"); + } + + [Theory] + [InlineData("Cookie: \r\n\r\n", "Cookie", "", null, null)] + [InlineData("Cookie:\r\n\r\n", "Cookie", "", null, null)] + [InlineData("Cookie: \r\nConnection: close\r\n\r\n", "Cookie", "", "Connection", "close")] + [InlineData("Cookie:\r\nConnection: close\r\n\r\n", "Cookie", "", "Connection", "close")] + [InlineData("Connection: close\r\nCookie: \r\n\r\n", "Connection", "close", "Cookie", "")] + [InlineData("Connection: close\r\nCookie:\r\n\r\n", "Connection", "close", "Cookie", "")] + public void ParseHeadersCanParseEmptyHeaderValues( + string rawHeaders, + string expectedHeaderName1, + string expectedHeaderValue1, + string expectedHeaderName2, + string expectedHeaderValue2) + { + var expectedHeaderNames = expectedHeaderName2 == null + ? new[] { expectedHeaderName1 } + : new[] { expectedHeaderName1, expectedHeaderName2 }; + var expectedHeaderValues = expectedHeaderValue2 == null + ? new[] { expectedHeaderValue1 } + : new[] { expectedHeaderValue1, expectedHeaderValue2 }; + + VerifyRawHeaders(rawHeaders, expectedHeaderNames, expectedHeaderValues); + } + + [Theory] + [InlineData(" value")] + [InlineData(" value")] + [InlineData("\tvalue")] + [InlineData(" \tvalue")] + [InlineData("\t value")] + [InlineData("\t\tvalue")] + [InlineData("\t\t value")] + [InlineData(" \t\tvalue")] + [InlineData(" \t\t value")] + [InlineData(" \t \t value")] + public void ParseHeadersDoesNotIncludeLeadingWhitespaceInHeaderValue(string rawHeaderValue) + { + VerifyHeader("Header", rawHeaderValue, "value"); + } + + [Theory] + [InlineData("value ")] + [InlineData("value\t")] + [InlineData("value \t")] + [InlineData("value\t ")] + [InlineData("value\t\t")] + [InlineData("value\t\t ")] + [InlineData("value \t\t")] + [InlineData("value \t\t ")] + [InlineData("value \t \t ")] + public void ParseHeadersDoesNotIncludeTrailingWhitespaceInHeaderValue(string rawHeaderValue) + { + VerifyHeader("Header", rawHeaderValue, "value"); + } + + [Theory] + [InlineData("one two three")] + [InlineData("one two three")] + [InlineData("one\ttwo\tthree")] + [InlineData("one two\tthree")] + [InlineData("one\ttwo three")] + [InlineData("one \ttwo \tthree")] + [InlineData("one\t two\t three")] + [InlineData("one \ttwo\t three")] + public void ParseHeadersPreservesWhitespaceWithinHeaderValue(string headerValue) + { + VerifyHeader("Header", headerValue, headerValue); + } + + [Fact] + public void ParseHeadersConsumesBytesCorrectlyAtEnd() + { + var parser = CreateParser(_nullTrace); + + const string headerLine = "Header: value\r\n\r"; + var buffer1 = new ReadOnlySequence(Encoding.ASCII.GetBytes(headerLine)); + var requestHandler = new RequestHandler(); + Assert.False(parser.ParseHeaders(requestHandler, buffer1, out var consumed, out var examined, out var consumedBytes)); + + Assert.Equal(buffer1.GetPosition(headerLine.Length - 1), consumed); + Assert.Equal(buffer1.End, examined); + Assert.Equal(headerLine.Length - 1, consumedBytes); + + var buffer2 = new ReadOnlySequence(Encoding.ASCII.GetBytes("\r\n")); + Assert.True(parser.ParseHeaders(requestHandler, buffer2, out consumed, out examined, out consumedBytes)); + + Assert.True(buffer2.Slice(consumed).IsEmpty); + Assert.True(buffer2.Slice(examined).IsEmpty); + Assert.Equal(2, consumedBytes); + } + + [Theory] + [MemberData(nameof(RequestHeaderInvalidData))] + public void ParseHeadersThrowsOnInvalidRequestHeaders(string rawHeaders, string expectedExceptionMessage) + { + var mockTrace = new Mock(); + mockTrace + .Setup(trace => trace.IsEnabled(LogLevel.Information)) + .Returns(true); + + var parser = CreateParser(mockTrace.Object); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(rawHeaders)); + var requestHandler = new RequestHandler(); + + var exception = Assert.Throws(() => + parser.ParseHeaders(requestHandler, buffer, out var consumed, out var examined, out var consumedBytes)); + + Assert.Equal(expectedExceptionMessage, exception.Message); + Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); + } + + [Fact] + public void ExceptionDetailNotIncludedWhenLogLevelInformationNotEnabled() + { + var mockTrace = new Mock(); + mockTrace + .Setup(trace => trace.IsEnabled(LogLevel.Information)) + .Returns(false); + + var parser = CreateParser(mockTrace.Object); + + // Invalid request line + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes("GET % HTTP/1.1\r\n")); + var requestHandler = new RequestHandler(); + + var exception = Assert.Throws(() => + parser.ParseRequestLine(requestHandler, buffer, out var consumed, out var examined)); + + Assert.Equal("Invalid request line: ''", exception.Message); + Assert.Equal(StatusCodes.Status400BadRequest, (exception as BadHttpRequestException).StatusCode); + + // Unrecognized HTTP version + buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes("GET / HTTP/1.2\r\n")); + + exception = Assert.Throws(() => + parser.ParseRequestLine(requestHandler, buffer, out var consumed, out var examined)); + + Assert.Equal(CoreStrings.FormatBadRequest_UnrecognizedHTTPVersion(string.Empty), exception.Message); + Assert.Equal(StatusCodes.Status505HttpVersionNotsupported, (exception as BadHttpRequestException).StatusCode); + + // Invalid request header + buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes("Header: value\n\r\n")); + + exception = Assert.Throws(() => + parser.ParseHeaders(requestHandler, buffer, out var consumed, out var examined, out var consumedBytes)); + + Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(string.Empty), exception.Message); + Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); + } + + [Fact] + public void ParseRequestLineSplitBufferWithoutNewLineDoesNotUpdateConsumed() + { + var parser = CreateParser(_nullTrace); + var buffer = ReadOnlySequenceFactory.CreateSegments( + Encoding.ASCII.GetBytes("GET "), + Encoding.ASCII.GetBytes("/")); + + var requestHandler = new RequestHandler(); + var result = parser.ParseRequestLine(requestHandler, buffer, out var consumed, out var examined); + + Assert.False(result); + Assert.Equal(buffer.Start, consumed); + Assert.Equal(buffer.End, examined); + } + + [Fact] + public void ParseHeadersWithGratuitouslySplitBuffers() + { + var parser = CreateParser(_nullTrace); + var buffer = BytePerSegmentTestSequenceFactory.Instance.CreateWithContent("Host:\r\nConnection: keep-alive\r\n\r\n"); + + var requestHandler = new RequestHandler(); + var result = parser.ParseHeaders(requestHandler, buffer, out var consumed, out var examined, out _); + + Assert.True(result); + } + + [Fact] + public void ParseHeadersWithGratuitouslySplitBuffers2() + { + var parser = CreateParser(_nullTrace); + var buffer = BytePerSegmentTestSequenceFactory.Instance.CreateWithContent("A:B\r\nB: C\r\n\r\n"); + + var requestHandler = new RequestHandler(); + var result = parser.ParseHeaders(requestHandler, buffer, out var consumed, out var examined, out _); + + Assert.True(result); + } + + private void VerifyHeader( + string headerName, + string rawHeaderValue, + string expectedHeaderValue) + { + var parser = CreateParser(_nullTrace); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes($"{headerName}:{rawHeaderValue}\r\n")); + + var requestHandler = new RequestHandler(); + parser.ParseHeaders(requestHandler, buffer, out var consumed, out var examined, out var consumedBytes); + + var pairs = requestHandler.Headers.ToArray(); + Assert.Single(pairs); + Assert.Equal(headerName, pairs[0].Key); + Assert.Equal(expectedHeaderValue, pairs[0].Value); + Assert.True(buffer.Slice(consumed).IsEmpty); + Assert.True(buffer.Slice(examined).IsEmpty); + } + + private void VerifyRawHeaders(string rawHeaders, IEnumerable expectedHeaderNames, IEnumerable expectedHeaderValues) + { + Assert.True(expectedHeaderNames.Count() == expectedHeaderValues.Count(), $"{nameof(expectedHeaderNames)} and {nameof(expectedHeaderValues)} sizes must match"); + + var parser = CreateParser(_nullTrace); + var buffer = new ReadOnlySequence(Encoding.ASCII.GetBytes(rawHeaders)); + + var requestHandler = new RequestHandler(); + parser.ParseHeaders(requestHandler, buffer, out var consumed, out var examined, out var consumedBytes); + + var parsedHeaders = requestHandler.Headers.ToArray(); + + Assert.Equal(expectedHeaderNames.Count(), parsedHeaders.Length); + Assert.Equal(expectedHeaderNames, parsedHeaders.Select(t => t.Key)); + Assert.Equal(expectedHeaderValues, parsedHeaders.Select(t => t.Value)); + Assert.True(buffer.Slice(consumed).IsEmpty); + Assert.True(buffer.Slice(examined).IsEmpty); + } + + private IHttpParser CreateParser(IKestrelTrace log) => new HttpParser(log.IsEnabled(LogLevel.Information)); + + public static IEnumerable RequestLineValidData => HttpParsingData.RequestLineValidData; + + public static IEnumerable RequestLineIncompleteData => HttpParsingData.RequestLineIncompleteData.Select(requestLine => new[] { requestLine }); + + public static IEnumerable RequestLineInvalidData => HttpParsingData.RequestLineInvalidData.Select(requestLine => new[] { requestLine }); + + public static IEnumerable MethodWithNonTokenCharData => HttpParsingData.MethodWithNonTokenCharData.Select(method => new[] { method }); + + public static TheoryData UnrecognizedHttpVersionData => HttpParsingData.UnrecognizedHttpVersionData; + + public static IEnumerable RequestHeaderInvalidData => HttpParsingData.RequestHeaderInvalidData; + + private class RequestHandler : IHttpRequestLineHandler, IHttpHeadersHandler + { + public string Method { get; set; } + + public string Version { get; set; } + + public string RawTarget { get; set; } + + public string RawPath { get; set; } + + public string Query { get; set; } + + public bool PathEncoded { get; set; } + + public Dictionary Headers { get; } = new Dictionary(); + + public void OnHeader(Span name, Span value) + { + Headers[name.GetAsciiStringNonNullCharacters()] = value.GetAsciiStringNonNullCharacters(); + } + + public void OnStartLine(HttpMethod method, HttpVersion version, Span target, Span path, Span query, Span customMethod, bool pathEncoded) + { + Method = method != HttpMethod.Custom ? HttpUtilities.MethodToString(method) : customMethod.GetAsciiStringNonNullCharacters(); + Version = HttpUtilities.VersionToString(version); + RawTarget = target.GetAsciiStringNonNullCharacters(); + RawPath = path.GetAsciiStringNonNullCharacters(); + Query = query.GetAsciiStringNonNullCharacters(); + PathEncoded = pathEncoded; + } + } + + // Doesn't put empty blocks inbetween every byte + internal class BytePerSegmentTestSequenceFactory : ReadOnlySequenceFactory + { + public static ReadOnlySequenceFactory Instance { get; } = new HttpParserTests.BytePerSegmentTestSequenceFactory(); + + public override ReadOnlySequence CreateOfSize(int size) + { + return CreateWithContent(new byte[size]); + } + + public override ReadOnlySequence CreateWithContent(byte[] data) + { + var segments = new List(); + + foreach (var b in data) + { + segments.Add(new[] { b }); + } + + return CreateSegments(segments.ToArray()); + } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HttpProtocolFeatureCollectionTests.cs b/src/Servers/Kestrel/Core/test/HttpProtocolFeatureCollectionTests.cs new file mode 100644 index 0000000000..84d08ff3d7 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HttpProtocolFeatureCollectionTests.cs @@ -0,0 +1,226 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Linq; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HttpProtocolFeatureCollectionTests : IDisposable + { + private readonly IDuplexPipe _transport; + private readonly IDuplexPipe _application; + private readonly TestHttp1Connection _http1Connection; + private readonly ServiceContext _serviceContext; + private readonly Http1ConnectionContext _http1ConnectionContext; + private readonly MemoryPool _memoryPool; + private Mock _timeoutControl; + + private readonly IFeatureCollection _collection; + + public HttpProtocolFeatureCollectionTests() + { + _memoryPool = KestrelMemoryPool.Create(); + var options = new PipeOptions(_memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + + _transport = pair.Transport; + _application = pair.Application; + + _serviceContext = new TestServiceContext(); + _timeoutControl = new Mock(); + _http1ConnectionContext = new Http1ConnectionContext + { + ServiceContext = _serviceContext, + ConnectionFeatures = new FeatureCollection(), + MemoryPool = _memoryPool, + TimeoutControl = _timeoutControl.Object, + Application = pair.Application, + Transport = pair.Transport + }; + + _http1Connection = new TestHttp1Connection(_http1ConnectionContext); + _http1Connection.Reset(); + _collection = _http1Connection; + } + + public void Dispose() + { + _transport.Input.Complete(); + _transport.Output.Complete(); + + _application.Input.Complete(); + _application.Output.Complete(); + + _memoryPool.Dispose(); + } + + [Fact] + public int FeaturesStartAsSelf() + { + var featureCount = 0; + foreach (var featureIter in _collection) + { + Type type = featureIter.Key; + if (type.IsAssignableFrom(typeof(HttpProtocol))) + { + var featureLookup = _collection[type]; + Assert.Same(featureLookup, featureIter.Value); + Assert.Same(featureLookup, _collection); + featureCount++; + } + } + + Assert.NotEqual(0, featureCount); + + return featureCount; + } + + [Fact] + public int FeaturesCanBeAssignedTo() + { + var featureCount = SetFeaturesToNonDefault(); + Assert.NotEqual(0, featureCount); + + featureCount = 0; + foreach (var feature in _collection) + { + Type type = feature.Key; + if (type.IsAssignableFrom(typeof(HttpProtocol))) + { + Assert.Same(_collection[type], feature.Value); + Assert.NotSame(_collection[type], _collection); + featureCount++; + } + } + + Assert.NotEqual(0, featureCount); + + return featureCount; + } + + [Fact] + public void FeaturesResetToSelf() + { + var featuresAssigned = SetFeaturesToNonDefault(); + _http1Connection.ResetFeatureCollection(); + var featuresReset = FeaturesStartAsSelf(); + + Assert.Equal(featuresAssigned, featuresReset); + } + + [Fact] + public void FeaturesByGenericSameAsByType() + { + var featuresAssigned = SetFeaturesToNonDefault(); + + CompareGenericGetterToIndexer(); + + _http1Connection.ResetFeatureCollection(); + var featuresReset = FeaturesStartAsSelf(); + + Assert.Equal(featuresAssigned, featuresReset); + } + + [Fact] + public void FeaturesSetByTypeSameAsGeneric() + { + _collection[typeof(IHttpRequestFeature)] = CreateHttp1Connection(); + _collection[typeof(IHttpResponseFeature)] = CreateHttp1Connection(); + _collection[typeof(IHttpRequestIdentifierFeature)] = CreateHttp1Connection(); + _collection[typeof(IHttpRequestLifetimeFeature)] = CreateHttp1Connection(); + _collection[typeof(IHttpConnectionFeature)] = CreateHttp1Connection(); + _collection[typeof(IHttpMaxRequestBodySizeFeature)] = CreateHttp1Connection(); + _collection[typeof(IHttpMinRequestBodyDataRateFeature)] = CreateHttp1Connection(); + _collection[typeof(IHttpMinResponseDataRateFeature)] = CreateHttp1Connection(); + _collection[typeof(IHttpBodyControlFeature)] = CreateHttp1Connection(); + + CompareGenericGetterToIndexer(); + + EachHttpProtocolFeatureSetAndUnique(); + } + + [Fact] + public void FeaturesSetByGenericSameAsByType() + { + _collection.Set(CreateHttp1Connection()); + _collection.Set(CreateHttp1Connection()); + _collection.Set(CreateHttp1Connection()); + _collection.Set(CreateHttp1Connection()); + _collection.Set(CreateHttp1Connection()); + _collection.Set(CreateHttp1Connection()); + _collection.Set(CreateHttp1Connection()); + _collection.Set(CreateHttp1Connection()); + _collection.Set(CreateHttp1Connection()); + + CompareGenericGetterToIndexer(); + + EachHttpProtocolFeatureSetAndUnique(); + } + + private void CompareGenericGetterToIndexer() + { + Assert.Same(_collection.Get(), _collection[typeof(IHttpRequestFeature)]); + Assert.Same(_collection.Get(), _collection[typeof(IHttpResponseFeature)]); + Assert.Same(_collection.Get(), _collection[typeof(IHttpRequestIdentifierFeature)]); + Assert.Same(_collection.Get(), _collection[typeof(IHttpRequestLifetimeFeature)]); + Assert.Same(_collection.Get(), _collection[typeof(IHttpConnectionFeature)]); + Assert.Same(_collection.Get(), _collection[typeof(IHttpMaxRequestBodySizeFeature)]); + Assert.Same(_collection.Get(), _collection[typeof(IHttpMinRequestBodyDataRateFeature)]); + Assert.Same(_collection.Get(), _collection[typeof(IHttpMinResponseDataRateFeature)]); + Assert.Same(_collection.Get(), _collection[typeof(IHttpBodyControlFeature)]); + } + + private int EachHttpProtocolFeatureSetAndUnique() + { + int featureCount = 0; + foreach (var item in _collection) + { + Type type = item.Key; + if (type.IsAssignableFrom(typeof(HttpProtocol))) + { + Assert.Equal(1, _collection.Count(kv => ReferenceEquals(kv.Value, item.Value))); + + featureCount++; + } + } + + Assert.NotEqual(0, featureCount); + + return featureCount; + } + + private int SetFeaturesToNonDefault() + { + int featureCount = 0; + foreach (var feature in _collection) + { + Type type = feature.Key; + if (type.IsAssignableFrom(typeof(HttpProtocol))) + { + _collection[type] = CreateHttp1Connection(); + featureCount++; + } + } + + var protocolFeaturesCount = EachHttpProtocolFeatureSetAndUnique(); + + Assert.Equal(protocolFeaturesCount, featureCount); + + return featureCount; + } + + private HttpProtocol CreateHttp1Connection() => new TestHttp1Connection(_http1ConnectionContext); + } +} diff --git a/src/Servers/Kestrel/Core/test/HttpRequestHeadersTests.cs b/src/Servers/Kestrel/Core/test/HttpRequestHeadersTests.cs new file mode 100644 index 0000000000..1eef7c792c --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HttpRequestHeadersTests.cs @@ -0,0 +1,314 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.Extensions.Primitives; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HttpRequestHeadersTests + { + [Fact] + public void InitialDictionaryIsEmpty() + { + IDictionary headers = new HttpRequestHeaders(); + + Assert.Equal(0, headers.Count); + Assert.False(headers.IsReadOnly); + } + + [Fact] + public void SettingUnknownHeadersWorks() + { + IDictionary headers = new HttpRequestHeaders(); + + headers["custom"] = new[] { "value" }; + + var header = Assert.Single(headers["custom"]); + Assert.Equal("value", header); + } + + [Fact] + public void SettingKnownHeadersWorks() + { + IDictionary headers = new HttpRequestHeaders(); + + headers["host"] = new[] { "value" }; + headers["content-length"] = new[] { "0" }; + + var host = Assert.Single(headers["host"]); + var contentLength = Assert.Single(headers["content-length"]); + Assert.Equal("value", host); + Assert.Equal("0", contentLength); + } + + [Fact] + public void KnownAndCustomHeaderCountAddedTogether() + { + IDictionary headers = new HttpRequestHeaders(); + + headers["host"] = new[] { "value" }; + headers["custom"] = new[] { "value" }; + headers["Content-Length"] = new[] { "0" }; + + Assert.Equal(3, headers.Count); + } + + [Fact] + public void TryGetValueWorksForKnownAndUnknownHeaders() + { + IDictionary headers = new HttpRequestHeaders(); + + StringValues value; + Assert.False(headers.TryGetValue("host", out value)); + Assert.False(headers.TryGetValue("custom", out value)); + Assert.False(headers.TryGetValue("Content-Length", out value)); + + headers["host"] = new[] { "value" }; + Assert.True(headers.TryGetValue("host", out value)); + Assert.False(headers.TryGetValue("custom", out value)); + Assert.False(headers.TryGetValue("Content-Length", out value)); + + headers["custom"] = new[] { "value" }; + Assert.True(headers.TryGetValue("host", out value)); + Assert.True(headers.TryGetValue("custom", out value)); + Assert.False(headers.TryGetValue("Content-Length", out value)); + + headers["Content-Length"] = new[] { "0" }; + Assert.True(headers.TryGetValue("host", out value)); + Assert.True(headers.TryGetValue("custom", out value)); + Assert.True(headers.TryGetValue("Content-Length", out value)); + } + + [Fact] + public void SameExceptionThrownForMissingKey() + { + IDictionary headers = new HttpRequestHeaders(); + + Assert.Throws(() => headers["custom"]); + Assert.Throws(() => headers["host"]); + Assert.Throws(() => headers["Content-Length"]); + } + + [Fact] + public void EntriesCanBeEnumerated() + { + IDictionary headers = new HttpRequestHeaders(); + var v1 = new[] { "localhost" }; + var v2 = new[] { "0" }; + var v3 = new[] { "value" }; + headers["host"] = v1; + headers["Content-Length"] = v2; + headers["custom"] = v3; + + Assert.Equal( + new[] { + new KeyValuePair("Host", v1), + new KeyValuePair("Content-Length", v2), + new KeyValuePair("custom", v3), + }, + headers); + } + + [Fact] + public void KeysAndValuesCanBeEnumerated() + { + IDictionary headers = new HttpRequestHeaders(); + StringValues v1 = new[] { "localhost" }; + StringValues v2 = new[] { "0" }; + StringValues v3 = new[] { "value" }; + headers["host"] = v1; + headers["Content-Length"] = v2; + headers["custom"] = v3; + + Assert.Equal( + new[] { "Host", "Content-Length", "custom" }, + headers.Keys); + + Assert.Equal( + new[] { v1, v2, v3 }, + headers.Values); + } + + [Fact] + public void ContainsAndContainsKeyWork() + { + IDictionary headers = new HttpRequestHeaders(); + var kv1 = new KeyValuePair("host", new[] { "localhost" }); + var kv2 = new KeyValuePair("custom", new[] { "value" }); + var kv3 = new KeyValuePair("Content-Length", new[] { "0" }); + var kv1b = new KeyValuePair("host", new[] { "not-localhost" }); + var kv2b = new KeyValuePair("custom", new[] { "not-value" }); + var kv3b = new KeyValuePair("Content-Length", new[] { "1" }); + + Assert.False(headers.ContainsKey("host")); + Assert.False(headers.ContainsKey("custom")); + Assert.False(headers.ContainsKey("Content-Length")); + Assert.False(headers.Contains(kv1)); + Assert.False(headers.Contains(kv2)); + Assert.False(headers.Contains(kv3)); + + headers["host"] = kv1.Value; + Assert.True(headers.ContainsKey("host")); + Assert.False(headers.ContainsKey("custom")); + Assert.False(headers.ContainsKey("Content-Length")); + Assert.True(headers.Contains(kv1)); + Assert.False(headers.Contains(kv2)); + Assert.False(headers.Contains(kv3)); + Assert.False(headers.Contains(kv1b)); + Assert.False(headers.Contains(kv2b)); + Assert.False(headers.Contains(kv3b)); + + headers["custom"] = kv2.Value; + Assert.True(headers.ContainsKey("host")); + Assert.True(headers.ContainsKey("custom")); + Assert.False(headers.ContainsKey("Content-Length")); + Assert.True(headers.Contains(kv1)); + Assert.True(headers.Contains(kv2)); + Assert.False(headers.Contains(kv3)); + Assert.False(headers.Contains(kv1b)); + Assert.False(headers.Contains(kv2b)); + Assert.False(headers.Contains(kv3b)); + + headers["Content-Length"] = kv3.Value; + Assert.True(headers.ContainsKey("host")); + Assert.True(headers.ContainsKey("custom")); + Assert.True(headers.ContainsKey("Content-Length")); + Assert.True(headers.Contains(kv1)); + Assert.True(headers.Contains(kv2)); + Assert.True(headers.Contains(kv3)); + Assert.False(headers.Contains(kv1b)); + Assert.False(headers.Contains(kv2b)); + Assert.False(headers.Contains(kv3b)); + } + + [Fact] + public void AddWorksLikeSetAndThrowsIfKeyExists() + { + IDictionary headers = new HttpRequestHeaders(); + + StringValues value; + Assert.False(headers.TryGetValue("host", out value)); + Assert.False(headers.TryGetValue("custom", out value)); + Assert.False(headers.TryGetValue("Content-Length", out value)); + + headers.Add("host", new[] { "localhost" }); + headers.Add("custom", new[] { "value" }); + headers.Add("Content-Length", new[] { "0" }); + Assert.True(headers.TryGetValue("host", out value)); + Assert.True(headers.TryGetValue("custom", out value)); + Assert.True(headers.TryGetValue("Content-Length", out value)); + + Assert.Throws(() => headers.Add("host", new[] { "localhost" })); + Assert.Throws(() => headers.Add("custom", new[] { "value" })); + Assert.Throws(() => headers.Add("Content-Length", new[] { "0" })); + Assert.True(headers.TryGetValue("host", out value)); + Assert.True(headers.TryGetValue("custom", out value)); + Assert.True(headers.TryGetValue("Content-Length", out value)); + } + + [Fact] + public void ClearRemovesAllHeaders() + { + IDictionary headers = new HttpRequestHeaders(); + headers.Add("host", new[] { "localhost" }); + headers.Add("custom", new[] { "value" }); + headers.Add("Content-Length", new[] { "0" }); + + StringValues value; + Assert.Equal(3, headers.Count); + Assert.True(headers.TryGetValue("host", out value)); + Assert.True(headers.TryGetValue("custom", out value)); + Assert.True(headers.TryGetValue("Content-Length", out value)); + + headers.Clear(); + + Assert.Equal(0, headers.Count); + Assert.False(headers.TryGetValue("host", out value)); + Assert.False(headers.TryGetValue("custom", out value)); + Assert.False(headers.TryGetValue("Content-Length", out value)); + } + + [Fact] + public void RemoveTakesHeadersOutOfDictionary() + { + IDictionary headers = new HttpRequestHeaders(); + headers.Add("host", new[] { "localhost" }); + headers.Add("custom", new[] { "value" }); + headers.Add("Content-Length", new[] { "0" }); + + StringValues value; + Assert.Equal(3, headers.Count); + Assert.True(headers.TryGetValue("host", out value)); + Assert.True(headers.TryGetValue("custom", out value)); + Assert.True(headers.TryGetValue("Content-Length", out value)); + + Assert.True(headers.Remove("host")); + Assert.False(headers.Remove("host")); + + Assert.Equal(2, headers.Count); + Assert.False(headers.TryGetValue("host", out value)); + Assert.True(headers.TryGetValue("custom", out value)); + + Assert.True(headers.Remove("custom")); + Assert.False(headers.Remove("custom")); + + Assert.Equal(1, headers.Count); + Assert.False(headers.TryGetValue("host", out value)); + Assert.False(headers.TryGetValue("custom", out value)); + Assert.True(headers.TryGetValue("Content-Length", out value)); + + Assert.True(headers.Remove("Content-Length")); + Assert.False(headers.Remove("Content-Length")); + + Assert.Equal(0, headers.Count); + Assert.False(headers.TryGetValue("host", out value)); + Assert.False(headers.TryGetValue("custom", out value)); + Assert.False(headers.TryGetValue("Content-Length", out value)); + } + + [Fact] + public void CopyToMovesDataIntoArray() + { + IDictionary headers = new HttpRequestHeaders(); + headers.Add("host", new[] { "localhost" }); + headers.Add("Content-Length", new[] { "0" }); + headers.Add("custom", new[] { "value" }); + + var entries = new KeyValuePair[5]; + headers.CopyTo(entries, 1); + + Assert.Null(entries[0].Key); + Assert.Equal(new StringValues(), entries[0].Value); + + Assert.Equal("Host", entries[1].Key); + Assert.Equal(new[] { "localhost" }, entries[1].Value); + + Assert.Equal("Content-Length", entries[2].Key); + Assert.Equal(new[] { "0" }, entries[2].Value); + + Assert.Equal("custom", entries[3].Key); + Assert.Equal(new[] { "value" }, entries[3].Value); + + Assert.Null(entries[4].Key); + Assert.Equal(new StringValues(), entries[4].Value); + } + + [Fact] + public void AppendThrowsWhenHeaderNameContainsNonASCIICharacters() + { + var headers = new HttpRequestHeaders(); + const string key = "\u00141\u00F3d\017c"; + + var encoding = Encoding.GetEncoding("iso-8859-1"); + var exception = Assert.Throws( + () => headers.Append(encoding.GetBytes(key), "value")); + Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HttpRequestStreamTests.cs b/src/Servers/Kestrel/Core/test/HttpRequestStreamTests.cs new file mode 100644 index 0000000000..1ffa74b027 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HttpRequestStreamTests.cs @@ -0,0 +1,213 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HttpRequestStreamTests + { + [Fact] + public void CanReadReturnsTrue() + { + var stream = new HttpRequestStream(Mock.Of()); + Assert.True(stream.CanRead); + } + + [Fact] + public void CanSeekReturnsFalse() + { + var stream = new HttpRequestStream(Mock.Of()); + Assert.False(stream.CanSeek); + } + + [Fact] + public void CanWriteReturnsFalse() + { + var stream = new HttpRequestStream(Mock.Of()); + Assert.False(stream.CanWrite); + } + + [Fact] + public void SeekThrows() + { + var stream = new HttpRequestStream(Mock.Of()); + Assert.Throws(() => stream.Seek(0, SeekOrigin.Begin)); + } + + [Fact] + public void LengthThrows() + { + var stream = new HttpRequestStream(Mock.Of()); + Assert.Throws(() => stream.Length); + } + + [Fact] + public void SetLengthThrows() + { + var stream = new HttpRequestStream(Mock.Of()); + Assert.Throws(() => stream.SetLength(0)); + } + + [Fact] + public void PositionThrows() + { + var stream = new HttpRequestStream(Mock.Of()); + Assert.Throws(() => stream.Position); + Assert.Throws(() => stream.Position = 0); + } + + [Fact] + public void WriteThrows() + { + var stream = new HttpRequestStream(Mock.Of()); + Assert.Throws(() => stream.Write(new byte[1], 0, 1)); + } + + [Fact] + public void WriteByteThrows() + { + var stream = new HttpRequestStream(Mock.Of()); + Assert.Throws(() => stream.WriteByte(0)); + } + + [Fact] + public async Task WriteAsyncThrows() + { + var stream = new HttpRequestStream(Mock.Of()); + await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1], 0, 1)); + } + +#if NET461 + [Fact] + public void BeginWriteThrows() + { + var stream = new HttpRequestStream(Mock.Of()); + Assert.Throws(() => stream.BeginWrite(new byte[1], 0, 1, null, null)); + } +#elif NETCOREAPP2_1 +#else +#error Target framework needs to be updated +#endif + + [Fact] + // Read-only streams should support Flush according to https://github.com/dotnet/corefx/pull/27327#pullrequestreview-98384813 + public void FlushDoesNotThrow() + { + var stream = new HttpRequestStream(Mock.Of()); + stream.Flush(); + } + + [Fact] + public async Task FlushAsyncDoesNotThrow() + { + var stream = new HttpRequestStream(Mock.Of()); + await stream.FlushAsync(); + } + + [Fact] + public async Task SynchronousReadsThrowIfDisallowedByIHttpBodyControlFeature() + { + var allowSynchronousIO = false; + var mockBodyControl = new Mock(); + mockBodyControl.Setup(m => m.AllowSynchronousIO).Returns(() => allowSynchronousIO); + var mockMessageBody = new Mock((HttpProtocol)null); + mockMessageBody.Setup(m => m.ReadAsync(It.IsAny>(), CancellationToken.None)).Returns(new ValueTask(0)); + + var stream = new HttpRequestStream(mockBodyControl.Object); + stream.StartAcceptingReads(mockMessageBody.Object); + + Assert.Equal(0, await stream.ReadAsync(new byte[1], 0, 1)); + + var ioEx = Assert.Throws(() => stream.Read(new byte[1], 0, 1)); + Assert.Equal("Synchronous operations are disallowed. Call ReadAsync or set AllowSynchronousIO to true instead.", ioEx.Message); + + var ioEx2 = Assert.Throws(() => stream.CopyTo(Stream.Null)); + Assert.Equal("Synchronous operations are disallowed. Call ReadAsync or set AllowSynchronousIO to true instead.", ioEx2.Message); + + allowSynchronousIO = true; + Assert.Equal(0, stream.Read(new byte[1], 0, 1)); + } + + [Fact] + public async Task AbortCausesReadToCancel() + { + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(null); + stream.Abort(); + await Assert.ThrowsAsync(() => stream.ReadAsync(new byte[1], 0, 1)); + } + + [Fact] + public async Task AbortWithErrorCausesReadToCancel() + { + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(null); + var error = new Exception(); + stream.Abort(error); + var exception = await Assert.ThrowsAsync(() => stream.ReadAsync(new byte[1], 0, 1)); + Assert.Same(error, exception); + } + + [Fact] + public void StopAcceptingReadsCausesReadToThrowObjectDisposedException() + { + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(null); + stream.StopAcceptingReads(); + Assert.Throws(() => { stream.ReadAsync(new byte[1], 0, 1); }); + } + + [Fact] + public async Task AbortCausesCopyToAsyncToCancel() + { + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(null); + stream.Abort(); + await Assert.ThrowsAsync(() => stream.CopyToAsync(Mock.Of())); + } + + [Fact] + public async Task AbortWithErrorCausesCopyToAsyncToCancel() + { + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(null); + var error = new Exception(); + stream.Abort(error); + var exception = await Assert.ThrowsAsync(() => stream.CopyToAsync(Mock.Of())); + Assert.Same(error, exception); + } + + [Fact] + public void StopAcceptingReadsCausesCopyToAsyncToThrowObjectDisposedException() + { + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(null); + stream.StopAcceptingReads(); + Assert.Throws(() => { stream.CopyToAsync(Mock.Of()); }); + } + + [Fact] + public void NullDestinationCausesCopyToAsyncToThrowArgumentNullException() + { + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(null); + Assert.Throws(() => { stream.CopyToAsync(null); }); + } + + [Fact] + public void ZeroBufferSizeCausesCopyToAsyncToThrowArgumentException() + { + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(null); + Assert.Throws(() => { stream.CopyToAsync(Mock.Of(), 0); }); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HttpResponseHeadersTests.cs b/src/Servers/Kestrel/Core/test/HttpResponseHeadersTests.cs new file mode 100644 index 0000000000..a8e2e7f441 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HttpResponseHeadersTests.cs @@ -0,0 +1,272 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Globalization; +using System.IO.Pipelines; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Primitives; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HttpResponseHeadersTests + { + [Fact] + public void InitialDictionaryIsEmpty() + { + using (var memoryPool = KestrelMemoryPool.Create()) + { + var options = new PipeOptions(memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + var http1ConnectionContext = new Http1ConnectionContext + { + ServiceContext = new TestServiceContext(), + ConnectionFeatures = new FeatureCollection(), + MemoryPool = memoryPool, + Application = pair.Application, + Transport = pair.Transport, + TimeoutControl = null + }; + + var http1Connection = new Http1Connection(http1ConnectionContext); + + http1Connection.Reset(); + + IDictionary headers = http1Connection.ResponseHeaders; + + Assert.Equal(0, headers.Count); + Assert.False(headers.IsReadOnly); + } + } + + [Theory] + [InlineData("Server", "\r\nData")] + [InlineData("Server", "\0Data")] + [InlineData("Server", "Data\r")] + [InlineData("Server", "Da\0ta")] + [InlineData("Server", "Da\u001Fta")] + [InlineData("Unknown-Header", "\r\nData")] + [InlineData("Unknown-Header", "\0Data")] + [InlineData("Unknown-Header", "Data\0")] + [InlineData("Unknown-Header", "Da\nta")] + [InlineData("\r\nServer", "Data")] + [InlineData("Server\r", "Data")] + [InlineData("Ser\0ver", "Data")] + [InlineData("Server\r\n", "Data")] + [InlineData("\u0000Server", "Data")] + [InlineData("Server", "Data\u0000")] + [InlineData("\u001FServer", "Data")] + [InlineData("Unknown-Header\r\n", "Data")] + [InlineData("\0Unknown-Header", "Data")] + [InlineData("Unknown\r-Header", "Data")] + [InlineData("Unk\nown-Header", "Data")] + [InlineData("Server", "Da\u007Fta")] + [InlineData("Unknown\u007F-Header", "Data")] + [InlineData("Ser\u0080ver", "Data")] + [InlineData("Server", "Da\u0080ta")] + [InlineData("Unknown\u0080-Header", "Data")] + [InlineData("Ser™ver", "Data")] + [InlineData("Server", "Da™ta")] + [InlineData("Unknown™-Header", "Data")] + [InlineData("šerver", "Data")] + [InlineData("Server", "Dašta")] + [InlineData("Unknownš-Header", "Data")] + [InlineData("Seršver", "Data")] + public void AddingControlOrNonAsciiCharactersToHeadersThrows(string key, string value) + { + var responseHeaders = new HttpResponseHeaders(); + + Assert.Throws(() => + { + ((IHeaderDictionary)responseHeaders)[key] = value; + }); + + Assert.Throws(() => + { + ((IHeaderDictionary)responseHeaders)[key] = new StringValues(new[] { "valid", value }); + }); + + Assert.Throws(() => + { + ((IDictionary)responseHeaders)[key] = value; + }); + + Assert.Throws(() => + { + var kvp = new KeyValuePair(key, value); + ((ICollection>)responseHeaders).Add(kvp); + }); + + Assert.Throws(() => + { + var kvp = new KeyValuePair(key, value); + ((IDictionary)responseHeaders).Add(key, value); + }); + } + + [Fact] + public void ThrowsWhenAddingHeaderAfterReadOnlyIsSet() + { + var headers = new HttpResponseHeaders(); + headers.SetReadOnly(); + + Assert.Throws(() => ((IDictionary)headers).Add("my-header", new[] { "value" })); + } + + [Fact] + public void ThrowsWhenChangingHeaderAfterReadOnlyIsSet() + { + var headers = new HttpResponseHeaders(); + var dictionary = (IDictionary)headers; + dictionary.Add("my-header", new[] { "value" }); + headers.SetReadOnly(); + + Assert.Throws(() => dictionary["my-header"] = "other-value"); + } + + [Fact] + public void ThrowsWhenRemovingHeaderAfterReadOnlyIsSet() + { + var headers = new HttpResponseHeaders(); + var dictionary = (IDictionary)headers; + dictionary.Add("my-header", new[] { "value" }); + headers.SetReadOnly(); + + Assert.Throws(() => dictionary.Remove("my-header")); + } + + [Fact] + public void ThrowsWhenClearingHeadersAfterReadOnlyIsSet() + { + var headers = new HttpResponseHeaders(); + var dictionary = (IDictionary)headers; + dictionary.Add("my-header", new[] { "value" }); + headers.SetReadOnly(); + + Assert.Throws(() => dictionary.Clear()); + } + + [Theory] + [MemberData(nameof(BadContentLengths))] + public void ThrowsWhenAddingContentLengthWithNonNumericValue(string contentLength) + { + var headers = new HttpResponseHeaders(); + var dictionary = (IDictionary)headers; + + var exception = Assert.Throws(() => dictionary.Add("Content-Length", new[] { contentLength })); + Assert.Equal(CoreStrings.FormatInvalidContentLength_InvalidNumber(contentLength), exception.Message); + } + + [Theory] + [MemberData(nameof(BadContentLengths))] + public void ThrowsWhenSettingContentLengthToNonNumericValue(string contentLength) + { + var headers = new HttpResponseHeaders(); + var dictionary = (IDictionary)headers; + + var exception = Assert.Throws(() => ((IHeaderDictionary)headers)["Content-Length"] = contentLength); + Assert.Equal(CoreStrings.FormatInvalidContentLength_InvalidNumber(contentLength), exception.Message); + } + + [Theory] + [MemberData(nameof(BadContentLengths))] + public void ThrowsWhenAssigningHeaderContentLengthToNonNumericValue(string contentLength) + { + var headers = new HttpResponseHeaders(); + + var exception = Assert.Throws(() => headers.HeaderContentLength = contentLength); + Assert.Equal(CoreStrings.FormatInvalidContentLength_InvalidNumber(contentLength), exception.Message); + } + + [Theory] + [MemberData(nameof(GoodContentLengths))] + public void ContentLengthValueCanBeReadAsLongAfterAddingHeader(string contentLength) + { + var headers = new HttpResponseHeaders(); + var dictionary = (IDictionary)headers; + dictionary.Add("Content-Length", contentLength); + + Assert.Equal(ParseLong(contentLength), headers.ContentLength); + } + + [Theory] + [MemberData(nameof(GoodContentLengths))] + public void ContentLengthValueCanBeReadAsLongAfterSettingHeader(string contentLength) + { + var headers = new HttpResponseHeaders(); + var dictionary = (IDictionary)headers; + dictionary["Content-Length"] = contentLength; + + Assert.Equal(ParseLong(contentLength), headers.ContentLength); + } + + [Theory] + [MemberData(nameof(GoodContentLengths))] + public void ContentLengthValueCanBeReadAsLongAfterAssigningHeader(string contentLength) + { + var headers = new HttpResponseHeaders(); + headers.HeaderContentLength = contentLength; + + Assert.Equal(ParseLong(contentLength), headers.ContentLength); + } + + [Fact] + public void ContentLengthValueClearedWhenHeaderIsRemoved() + { + var headers = new HttpResponseHeaders(); + headers.HeaderContentLength = "42"; + var dictionary = (IDictionary)headers; + + dictionary.Remove("Content-Length"); + + Assert.Null(headers.ContentLength); + } + + [Fact] + public void ContentLengthValueClearedWhenHeadersCleared() + { + var headers = new HttpResponseHeaders(); + headers.HeaderContentLength = "42"; + var dictionary = (IDictionary)headers; + + dictionary.Clear(); + + Assert.Null(headers.ContentLength); + } + + private static long ParseLong(string value) + { + return long.Parse(value, NumberStyles.AllowLeadingWhite | NumberStyles.AllowTrailingWhite, CultureInfo.InvariantCulture); + } + + public static TheoryData GoodContentLengths => new TheoryData + { + "0", + "00", + "042", + "42", + long.MaxValue.ToString(CultureInfo.InvariantCulture) + }; + + public static TheoryData BadContentLengths => new TheoryData + { + "", + " ", + " 42", + "42 ", + "bad", + "!", + "!42", + "42!", + "42,000", + "42.000", + }; + } +} diff --git a/src/Servers/Kestrel/Core/test/HttpResponseStreamTests.cs b/src/Servers/Kestrel/Core/test/HttpResponseStreamTests.cs new file mode 100644 index 0000000000..5e49e8a5a2 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HttpResponseStreamTests.cs @@ -0,0 +1,129 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Tests.TestHelpers; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HttpResponseStreamTests + { + [Fact] + public void CanReadReturnsFalse() + { + var stream = new HttpResponseStream(Mock.Of(), new MockHttpResponseControl()); + Assert.False(stream.CanRead); + } + + [Fact] + public void CanSeekReturnsFalse() + { + var stream = new HttpResponseStream(Mock.Of(), new MockHttpResponseControl()); + Assert.False(stream.CanSeek); + } + + [Fact] + public void CanWriteReturnsTrue() + { + var stream = new HttpResponseStream(Mock.Of(), new MockHttpResponseControl()); + Assert.True(stream.CanWrite); + } + + [Fact] + public void ReadThrows() + { + var stream = new HttpResponseStream(Mock.Of(), new MockHttpResponseControl()); + Assert.Throws(() => stream.Read(new byte[1], 0, 1)); + } + + [Fact] + public void ReadByteThrows() + { + var stream = new HttpResponseStream(Mock.Of(), new MockHttpResponseControl()); + Assert.Throws(() => stream.ReadByte()); + } + + [Fact] + public async Task ReadAsyncThrows() + { + var stream = new HttpResponseStream(Mock.Of(), new MockHttpResponseControl()); + await Assert.ThrowsAsync(() => stream.ReadAsync(new byte[1], 0, 1)); + } + + [Fact] + public void BeginReadThrows() + { + var stream = new HttpResponseStream(Mock.Of(), new MockHttpResponseControl()); + Assert.Throws(() => stream.BeginRead(new byte[1], 0, 1, null, null)); + } + + [Fact] + public void SeekThrows() + { + var stream = new HttpResponseStream(Mock.Of(), new MockHttpResponseControl()); + Assert.Throws(() => stream.Seek(0, SeekOrigin.Begin)); + } + + [Fact] + public void LengthThrows() + { + var stream = new HttpResponseStream(Mock.Of(), new MockHttpResponseControl()); + Assert.Throws(() => stream.Length); + } + + [Fact] + public void SetLengthThrows() + { + var stream = new HttpResponseStream(Mock.Of(), new MockHttpResponseControl()); + Assert.Throws(() => stream.SetLength(0)); + } + + [Fact] + public void PositionThrows() + { + var stream = new HttpResponseStream(Mock.Of(), new MockHttpResponseControl()); + Assert.Throws(() => stream.Position); + Assert.Throws(() => stream.Position = 0); + } + + [Fact] + public void StopAcceptingWritesCausesWriteToThrowObjectDisposedException() + { + var stream = new HttpResponseStream(Mock.Of(), Mock.Of()); + stream.StartAcceptingWrites(); + stream.StopAcceptingWrites(); + var ex = Assert.Throws(() => { stream.WriteAsync(new byte[1], 0, 1); }); + Assert.Contains(CoreStrings.WritingToResponseBodyAfterResponseCompleted, ex.Message); + } + + [Fact] + public async Task SynchronousWritesThrowIfDisallowedByIHttpBodyControlFeature() + { + var allowSynchronousIO = false; + var mockBodyControl = new Mock(); + mockBodyControl.Setup(m => m.AllowSynchronousIO).Returns(() => allowSynchronousIO); + var mockHttpResponseControl = new Mock(); + mockHttpResponseControl.Setup(m => m.WriteAsync(It.IsAny>(), CancellationToken.None)).Returns(Task.CompletedTask); + + var stream = new HttpResponseStream(mockBodyControl.Object, mockHttpResponseControl.Object); + stream.StartAcceptingWrites(); + + // WriteAsync doesn't throw. + await stream.WriteAsync(new byte[1], 0, 1); + + var ioEx = Assert.Throws(() => stream.Write(new byte[1], 0, 1)); + Assert.Equal("Synchronous operations are disallowed. Call WriteAsync or set AllowSynchronousIO to true instead.", ioEx.Message); + + allowSynchronousIO = true; + // If IHttpBodyControlFeature.AllowSynchronousIO is true, Write no longer throws. + stream.Write(new byte[1], 0, 1); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/HttpUtilitiesTest.cs b/src/Servers/Kestrel/Core/test/HttpUtilitiesTest.cs new file mode 100644 index 0000000000..b503eab04e --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HttpUtilitiesTest.cs @@ -0,0 +1,231 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Text; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HttpUtilitiesTest + { + [Theory] + [InlineData("CONNECT / HTTP/1.1", true, "CONNECT", HttpMethod.Connect)] + [InlineData("DELETE / HTTP/1.1", true, "DELETE", HttpMethod.Delete)] + [InlineData("GET / HTTP/1.1", true, "GET", HttpMethod.Get)] + [InlineData("HEAD / HTTP/1.1", true, "HEAD", HttpMethod.Head)] + [InlineData("PATCH / HTTP/1.1", true, "PATCH", HttpMethod.Patch)] + [InlineData("POST / HTTP/1.1", true, "POST", HttpMethod.Post)] + [InlineData("PUT / HTTP/1.1", true, "PUT", HttpMethod.Put)] + [InlineData("OPTIONS / HTTP/1.1", true, "OPTIONS", HttpMethod.Options)] + [InlineData("TRACE / HTTP/1.1", true, "TRACE", HttpMethod.Trace)] + [InlineData("GET/ HTTP/1.1", false, null, HttpMethod.Custom)] + [InlineData("get / HTTP/1.1", false, null, HttpMethod.Custom)] + [InlineData("GOT / HTTP/1.1", false, null, HttpMethod.Custom)] + [InlineData("ABC / HTTP/1.1", false, null, HttpMethod.Custom)] + [InlineData("PO / HTTP/1.1", false, null, HttpMethod.Custom)] + [InlineData("PO ST / HTTP/1.1", false, null, HttpMethod.Custom)] + [InlineData("short ", false, null, HttpMethod.Custom)] + public void GetsKnownMethod(string input, bool expectedResult, string expectedKnownString, HttpMethod expectedMethod) + { + // Arrange + var block = new Span(Encoding.ASCII.GetBytes(input)); + + // Act + HttpMethod knownMethod; + var result = block.GetKnownMethod(out knownMethod, out var length); + + string toString = null; + if (knownMethod != HttpMethod.Custom) + { + toString = HttpUtilities.MethodToString(knownMethod); + } + + + // Assert + Assert.Equal(expectedResult, result); + Assert.Equal(expectedMethod, knownMethod); + Assert.Equal(toString, expectedKnownString); + Assert.Equal(length, expectedKnownString?.Length ?? 0); + } + + [Theory] + [InlineData("HTTP/1.0\r", true, HttpUtilities.Http10Version, HttpVersion.Http10)] + [InlineData("HTTP/1.1\r", true, HttpUtilities.Http11Version, HttpVersion.Http11)] + [InlineData("HTTP/3.0\r", false, null, HttpVersion.Unknown)] + [InlineData("http/1.0\r", false, null, HttpVersion.Unknown)] + [InlineData("http/1.1\r", false, null, HttpVersion.Unknown)] + [InlineData("short ", false, null, HttpVersion.Unknown)] + public void GetsKnownVersion(string input, bool expectedResult, string expectedKnownString, HttpVersion version) + { + // Arrange + var block = new Span(Encoding.ASCII.GetBytes(input)); + + // Act + var result = block.GetKnownVersion(out HttpVersion knownVersion, out var length); + string toString = null; + if (knownVersion != HttpVersion.Unknown) + { + toString = HttpUtilities.VersionToString(knownVersion); + } + + // Assert + Assert.Equal(version, knownVersion); + Assert.Equal(expectedResult, result); + Assert.Equal(expectedKnownString, toString); + Assert.Equal(expectedKnownString?.Length ?? 0, length); + } + + [Theory] + [InlineData("HTTP/1.0\r", "HTTP/1.0")] + [InlineData("HTTP/1.1\r", "HTTP/1.1")] + public void KnownVersionsAreInterned(string input, string expected) + { + TestKnownStringsInterning(input, expected, span => + { + HttpUtilities.GetKnownVersion(span, out var version, out var _); + return HttpUtilities.VersionToString(version); + }); + } + + [Theory] + [InlineData("https://host/", "https://")] + [InlineData("http://host/", "http://")] + public void KnownSchemesAreInterned(string input, string expected) + { + TestKnownStringsInterning(input, expected, span => + { + HttpUtilities.GetKnownHttpScheme(span, out var scheme); + return HttpUtilities.SchemeToString(scheme); + }); + } + + [Theory] + [InlineData("CONNECT / HTTP/1.1", "CONNECT")] + [InlineData("DELETE / HTTP/1.1", "DELETE")] + [InlineData("GET / HTTP/1.1", "GET")] + [InlineData("HEAD / HTTP/1.1", "HEAD")] + [InlineData("PATCH / HTTP/1.1", "PATCH")] + [InlineData("POST / HTTP/1.1", "POST")] + [InlineData("PUT / HTTP/1.1", "PUT")] + [InlineData("OPTIONS / HTTP/1.1", "OPTIONS")] + [InlineData("TRACE / HTTP/1.1", "TRACE")] + public void KnownMethodsAreInterned(string input, string expected) + { + TestKnownStringsInterning(input, expected, span => + { + HttpUtilities.GetKnownMethod(span, out var method, out var length); + return HttpUtilities.MethodToString(method); + }); + } + + private void TestKnownStringsInterning(string input, string expected, Func action) + { + // Act + var knownString1 = action(Encoding.ASCII.GetBytes(input)); + var knownString2 = action(Encoding.ASCII.GetBytes(input)); + + // Assert + Assert.Equal(knownString1, expected); + Assert.Same(knownString1, knownString2); + } + + public static TheoryData HostHeaderData + { + get + { + return new TheoryData() { + "z", + "1", + "y:1", + "1:1", + "[ABCdef]", + "[abcDEF]:0", + "[abcdef:127.2355.1246.114]:0", + "[::1]:80", + "127.0.0.1:80", + "900.900.900.900:9523547852", + "foo", + "foo:234", + "foo.bar.baz", + "foo.BAR.baz:46245", + "foo.ba-ar.baz:46245", + "-foo:1234", + "xn--asdfaf:134", + "-", + "_", + "~", + "!", + "$", + "'", + "(", + ")", + }; + } + } + + [Theory] + [MemberData(nameof(HostHeaderData))] + public void ValidHostHeadersParsed(string host) + { + HttpUtilities.ValidateHostHeader(host); + // Shouldn't throw + } + + public static TheoryData HostHeaderInvalidData + { + get + { + // see https://tools.ietf.org/html/rfc7230#section-5.4 + var data = new TheoryData() { + "[]", // Too short + "[::]", // Too short + "[ghijkl]", // Non-hex + "[afd:adf:123", // Incomplete + "[afd:adf]123", // Missing : + "[afd:adf]:", // Missing port digits + "[afd adf]", // Space + "[ad-314]", // dash + ":1234", // Missing host + "a:b:c", // Missing [] + "::1", // Missing [] + "::", // Missing everything + "abcd:1abcd", // Letters in port + "abcd:1.2", // Dot in port + "1.2.3.4:", // Missing port digits + "1.2 .4", // Space + }; + + // These aren't allowed anywhere in the host header + var invalid = "\"#%*+,/;<=>?@[]\\^`{}|"; + foreach (var ch in invalid) + { + data.Add(ch.ToString()); + } + + invalid = "!\"#$%&'()*+,/;<=>?@[]\\^_`{}|~-"; + foreach (var ch in invalid) + { + data.Add("[abd" + ch + "]:1234"); + } + + invalid = "!\"#$%&'()*+,/;<=>?@[]\\^_`{}|~:abcABC-."; + foreach (var ch in invalid) + { + data.Add("a.b.c:" + ch); + } + + return data; + } + } + + [Theory] + [MemberData(nameof(HostHeaderInvalidData))] + public void InvalidHostHeadersRejected(string host) + { + Assert.Throws(() => HttpUtilities.ValidateHostHeader(host)); + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/test/HuffmanTests.cs b/src/Servers/Kestrel/Core/test/HuffmanTests.cs new file mode 100644 index 0000000000..cbe87cd4d9 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/HuffmanTests.cs @@ -0,0 +1,451 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Text; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HuffmanTests + { + public static readonly TheoryData _validData = new TheoryData + { + // Single 5-bit symbol + { new byte[] { 0x07 }, Encoding.ASCII.GetBytes("0") }, + // Single 6-bit symbol + { new byte[] { 0x57 }, Encoding.ASCII.GetBytes("%") }, + // Single 7-bit symbol + { new byte[] { 0xb9 }, Encoding.ASCII.GetBytes(":") }, + // Single 8-bit symbol + { new byte[] { 0xf8 }, Encoding.ASCII.GetBytes("&") }, + // Single 10-bit symbol + { new byte[] { 0xfe, 0x3f }, Encoding.ASCII.GetBytes("!") }, + // Single 11-bit symbol + { new byte[] { 0xff, 0x7f }, Encoding.ASCII.GetBytes("+") }, + // Single 12-bit symbol + { new byte[] { 0xff, 0xaf }, Encoding.ASCII.GetBytes("#") }, + // Single 13-bit symbol + { new byte[] { 0xff, 0xcf }, Encoding.ASCII.GetBytes("$") }, + // Single 14-bit symbol + { new byte[] { 0xff, 0xf3 }, Encoding.ASCII.GetBytes("^") }, + // Single 15-bit symbol + { new byte[] { 0xff, 0xf9 }, Encoding.ASCII.GetBytes("<") }, + // Single 19-bit symbol + { new byte[] { 0xff, 0xfe, 0x1f }, Encoding.ASCII.GetBytes("\\") }, + // Single 20-bit symbol + { new byte[] { 0xff, 0xfe, 0x6f }, new byte[] { 0x80 } }, + // Single 21-bit symbol + { new byte[] { 0xff, 0xfe, 0xe7 }, new byte[] { 0x99 } }, + // Single 22-bit symbol + { new byte[] { 0xff, 0xff, 0x4b }, new byte[] { 0x81 } }, + // Single 23-bit symbol + { new byte[] { 0xff, 0xff, 0xb1 }, new byte[] { 0x01 } }, + // Single 24-bit symbol + { new byte[] { 0xff, 0xff, 0xea }, new byte[] { 0x09 } }, + // Single 25-bit symbol + { new byte[] { 0xff, 0xff, 0xf6, 0x7f }, new byte[] { 0xc7 } }, + // Single 26-bit symbol + { new byte[] { 0xff, 0xff, 0xf8, 0x3f }, new byte[] { 0xc0 } }, + // Single 27-bit symbol + { new byte[] { 0xff, 0xff, 0xfb, 0xdf }, new byte[] { 0xcb } }, + // Single 28-bit symbol + { new byte[] { 0xff, 0xff, 0xfe, 0x2f }, new byte[] { 0x02 } }, + // Single 30-bit symbol + { new byte[] { 0xff, 0xff, 0xff, 0xf3 }, new byte[] { 0x0a } }, + + // h e l l o * + { new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_1111 }, Encoding.ASCII.GetBytes("hello") }, + + // Sequences that uncovered errors + { new byte[] { 0xb6, 0xb9, 0xac, 0x1c, 0x85, 0x58, 0xd5, 0x20, 0xa4, 0xb6, 0xc2, 0xad, 0x61, 0x7b, 0x5a, 0x54, 0x25, 0x1f }, Encoding.ASCII.GetBytes("upgrade-insecure-requests") }, + { new byte[] { 0xfe, 0x53 }, Encoding.ASCII.GetBytes("\"t") } + }; + + [Theory] + [MemberData(nameof(_validData))] + public void HuffmanDecodeArray(byte[] encoded, byte[] expected) + { + var dst = new byte[expected.Length]; + Assert.Equal(expected.Length, Huffman.Decode(encoded, 0, encoded.Length, dst)); + Assert.Equal(expected, dst); + } + + public static readonly TheoryData _longPaddingData = new TheoryData + { + // h e l l o * + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_1111, 0b11111111 }, + + // '&' (8 bits) + 8 bit padding + new byte[] { 0xf8, 0xff }, + + // ':' (7 bits) + 9 bit padding + new byte[] { 0xb9, 0xff } + }; + + [Theory] + [MemberData(nameof(_longPaddingData))] + public void ThrowsOnPaddingLongerThanSevenBits(byte[] encoded) + { + var exception = Assert.Throws(() => Huffman.Decode(encoded, 0, encoded.Length, new byte[encoded.Length * 2])); + Assert.Equal(CoreStrings.HPackHuffmanErrorIncomplete, exception.Message); + } + + public static readonly TheoryData _eosData = new TheoryData + { + // EOS + new byte[] { 0xff, 0xff, 0xff, 0xff }, + // '&' + EOS + '0' + new byte[] { 0xf8, 0xff, 0xff, 0xff, 0xfc, 0x1f } + }; + + [Theory] + [MemberData(nameof(_eosData))] + public void ThrowsOnEOS(byte[] encoded) + { + var exception = Assert.Throws(() => Huffman.Decode(encoded, 0, encoded.Length, new byte[encoded.Length * 2])); + Assert.Equal(CoreStrings.HPackHuffmanErrorEOS, exception.Message); + } + + [Fact] + public void ThrowsOnDestinationBufferTooSmall() + { + // h e l l o * + var encoded = new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_1111 }; + var exception = Assert.Throws(() => Huffman.Decode(encoded, 0, encoded.Length, new byte[encoded.Length])); + Assert.Equal(CoreStrings.HPackHuffmanErrorDestinationTooSmall, exception.Message); + } + + public static readonly TheoryData _incompleteSymbolData = new TheoryData + { + // h e l l o (incomplete) + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0 }, + + // Non-zero padding will be seen as incomplete symbol + // h e l l o * + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_0000 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_0001 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_0010 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_0011 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_0100 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_0101 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_0110 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_0111 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_1000 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_1001 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_1010 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_1011 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_1100 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_1101 }, + new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_1110 } + }; + + [Theory] + [MemberData(nameof(_incompleteSymbolData))] + public void ThrowsOnIncompleteSymbol(byte[] encoded) + { + var exception = Assert.Throws(() => Huffman.Decode(encoded, 0, encoded.Length, new byte[encoded.Length * 2])); + Assert.Equal(CoreStrings.HPackHuffmanErrorIncomplete, exception.Message); + } + + [Theory] + [MemberData(nameof(HuffmanData))] + public void HuffmanEncode(int code, uint expectedEncoded, int expectedBitLength) + { + var (encoded, bitLength) = Huffman.Encode(code); + Assert.Equal(expectedEncoded, encoded); + Assert.Equal(expectedBitLength, bitLength); + } + + [Theory] + [MemberData(nameof(HuffmanData))] + public void HuffmanDecode(int code, uint encoded, int bitLength) + { + Assert.Equal(code, Huffman.Decode(encoded, bitLength, out var decodedBits)); + Assert.Equal(bitLength, decodedBits); + } + + [Theory] + [MemberData(nameof(HuffmanData))] + public void HuffmanEncodeDecode( + int code, +// Suppresses the warning about an unused theory parameter because +// this test shares data with other methods +#pragma warning disable xUnit1026 + uint encoded, +#pragma warning restore xUnit1026 + int bitLength) + { + Assert.Equal(code, Huffman.Decode(Huffman.Encode(code).encoded, bitLength, out var decodedBits)); + Assert.Equal(bitLength, decodedBits); + } + + public static TheoryData HuffmanData + { + get + { + var data = new TheoryData(); + + data.Add(0, 0b11111111_11000000_00000000_00000000, 13); + data.Add(1, 0b11111111_11111111_10110000_00000000, 23); + data.Add(2, 0b11111111_11111111_11111110_00100000, 28); + data.Add(3, 0b11111111_11111111_11111110_00110000, 28); + data.Add(4, 0b11111111_11111111_11111110_01000000, 28); + data.Add(5, 0b11111111_11111111_11111110_01010000, 28); + data.Add(6, 0b11111111_11111111_11111110_01100000, 28); + data.Add(7, 0b11111111_11111111_11111110_01110000, 28); + data.Add(8, 0b11111111_11111111_11111110_10000000, 28); + data.Add(9, 0b11111111_11111111_11101010_00000000, 24); + data.Add(10, 0b11111111_11111111_11111111_11110000, 30); + data.Add(11, 0b11111111_11111111_11111110_10010000, 28); + data.Add(12, 0b11111111_11111111_11111110_10100000, 28); + data.Add(13, 0b11111111_11111111_11111111_11110100, 30); + data.Add(14, 0b11111111_11111111_11111110_10110000, 28); + data.Add(15, 0b11111111_11111111_11111110_11000000, 28); + data.Add(16, 0b11111111_11111111_11111110_11010000, 28); + data.Add(17, 0b11111111_11111111_11111110_11100000, 28); + data.Add(18, 0b11111111_11111111_11111110_11110000, 28); + data.Add(19, 0b11111111_11111111_11111111_00000000, 28); + data.Add(20, 0b11111111_11111111_11111111_00010000, 28); + data.Add(21, 0b11111111_11111111_11111111_00100000, 28); + data.Add(22, 0b11111111_11111111_11111111_11111000, 30); + data.Add(23, 0b11111111_11111111_11111111_00110000, 28); + data.Add(24, 0b11111111_11111111_11111111_01000000, 28); + data.Add(25, 0b11111111_11111111_11111111_01010000, 28); + data.Add(26, 0b11111111_11111111_11111111_01100000, 28); + data.Add(27, 0b11111111_11111111_11111111_01110000, 28); + data.Add(28, 0b11111111_11111111_11111111_10000000, 28); + data.Add(29, 0b11111111_11111111_11111111_10010000, 28); + data.Add(30, 0b11111111_11111111_11111111_10100000, 28); + data.Add(31, 0b11111111_11111111_11111111_10110000, 28); + data.Add(32, 0b01010000_00000000_00000000_00000000, 6); + data.Add(33, 0b11111110_00000000_00000000_00000000, 10); + data.Add(34, 0b11111110_01000000_00000000_00000000, 10); + data.Add(35, 0b11111111_10100000_00000000_00000000, 12); + data.Add(36, 0b11111111_11001000_00000000_00000000, 13); + data.Add(37, 0b01010100_00000000_00000000_00000000, 6); + data.Add(38, 0b11111000_00000000_00000000_00000000, 8); + data.Add(39, 0b11111111_01000000_00000000_00000000, 11); + data.Add(40, 0b11111110_10000000_00000000_00000000, 10); + data.Add(41, 0b11111110_11000000_00000000_00000000, 10); + data.Add(42, 0b11111001_00000000_00000000_00000000, 8); + data.Add(43, 0b11111111_01100000_00000000_00000000, 11); + data.Add(44, 0b11111010_00000000_00000000_00000000, 8); + data.Add(45, 0b01011000_00000000_00000000_00000000, 6); + data.Add(46, 0b01011100_00000000_00000000_00000000, 6); + data.Add(47, 0b01100000_00000000_00000000_00000000, 6); + data.Add(48, 0b00000000_00000000_00000000_00000000, 5); + data.Add(49, 0b00001000_00000000_00000000_00000000, 5); + data.Add(50, 0b00010000_00000000_00000000_00000000, 5); + data.Add(51, 0b01100100_00000000_00000000_00000000, 6); + data.Add(52, 0b01101000_00000000_00000000_00000000, 6); + data.Add(53, 0b01101100_00000000_00000000_00000000, 6); + data.Add(54, 0b01110000_00000000_00000000_00000000, 6); + data.Add(55, 0b01110100_00000000_00000000_00000000, 6); + data.Add(56, 0b01111000_00000000_00000000_00000000, 6); + data.Add(57, 0b01111100_00000000_00000000_00000000, 6); + data.Add(58, 0b10111000_00000000_00000000_00000000, 7); + data.Add(59, 0b11111011_00000000_00000000_00000000, 8); + data.Add(60, 0b11111111_11111000_00000000_00000000, 15); + data.Add(61, 0b10000000_00000000_00000000_00000000, 6); + data.Add(62, 0b11111111_10110000_00000000_00000000, 12); + data.Add(63, 0b11111111_00000000_00000000_00000000, 10); + data.Add(64, 0b11111111_11010000_00000000_00000000, 13); + data.Add(65, 0b10000100_00000000_00000000_00000000, 6); + data.Add(66, 0b10111010_00000000_00000000_00000000, 7); + data.Add(67, 0b10111100_00000000_00000000_00000000, 7); + data.Add(68, 0b10111110_00000000_00000000_00000000, 7); + data.Add(69, 0b11000000_00000000_00000000_00000000, 7); + data.Add(70, 0b11000010_00000000_00000000_00000000, 7); + data.Add(71, 0b11000100_00000000_00000000_00000000, 7); + data.Add(72, 0b11000110_00000000_00000000_00000000, 7); + data.Add(73, 0b11001000_00000000_00000000_00000000, 7); + data.Add(74, 0b11001010_00000000_00000000_00000000, 7); + data.Add(75, 0b11001100_00000000_00000000_00000000, 7); + data.Add(76, 0b11001110_00000000_00000000_00000000, 7); + data.Add(77, 0b11010000_00000000_00000000_00000000, 7); + data.Add(78, 0b11010010_00000000_00000000_00000000, 7); + data.Add(79, 0b11010100_00000000_00000000_00000000, 7); + data.Add(80, 0b11010110_00000000_00000000_00000000, 7); + data.Add(81, 0b11011000_00000000_00000000_00000000, 7); + data.Add(82, 0b11011010_00000000_00000000_00000000, 7); + data.Add(83, 0b11011100_00000000_00000000_00000000, 7); + data.Add(84, 0b11011110_00000000_00000000_00000000, 7); + data.Add(85, 0b11100000_00000000_00000000_00000000, 7); + data.Add(86, 0b11100010_00000000_00000000_00000000, 7); + data.Add(87, 0b11100100_00000000_00000000_00000000, 7); + data.Add(88, 0b11111100_00000000_00000000_00000000, 8); + data.Add(89, 0b11100110_00000000_00000000_00000000, 7); + data.Add(90, 0b11111101_00000000_00000000_00000000, 8); + data.Add(91, 0b11111111_11011000_00000000_00000000, 13); + data.Add(92, 0b11111111_11111110_00000000_00000000, 19); + data.Add(93, 0b11111111_11100000_00000000_00000000, 13); + data.Add(94, 0b11111111_11110000_00000000_00000000, 14); + data.Add(95, 0b10001000_00000000_00000000_00000000, 6); + data.Add(96, 0b11111111_11111010_00000000_00000000, 15); + data.Add(97, 0b00011000_00000000_00000000_00000000, 5); + data.Add(98, 0b10001100_00000000_00000000_00000000, 6); + data.Add(99, 0b00100000_00000000_00000000_00000000, 5); + data.Add(100, 0b10010000_00000000_00000000_00000000, 6); + data.Add(101, 0b00101000_00000000_00000000_00000000, 5); + data.Add(102, 0b10010100_00000000_00000000_00000000, 6); + data.Add(103, 0b10011000_00000000_00000000_00000000, 6); + data.Add(104, 0b10011100_00000000_00000000_00000000, 6); + data.Add(105, 0b00110000_00000000_00000000_00000000, 5); + data.Add(106, 0b11101000_00000000_00000000_00000000, 7); + data.Add(107, 0b11101010_00000000_00000000_00000000, 7); + data.Add(108, 0b10100000_00000000_00000000_00000000, 6); + data.Add(109, 0b10100100_00000000_00000000_00000000, 6); + data.Add(110, 0b10101000_00000000_00000000_00000000, 6); + data.Add(111, 0b00111000_00000000_00000000_00000000, 5); + data.Add(112, 0b10101100_00000000_00000000_00000000, 6); + data.Add(113, 0b11101100_00000000_00000000_00000000, 7); + data.Add(114, 0b10110000_00000000_00000000_00000000, 6); + data.Add(115, 0b01000000_00000000_00000000_00000000, 5); + data.Add(116, 0b01001000_00000000_00000000_00000000, 5); + data.Add(117, 0b10110100_00000000_00000000_00000000, 6); + data.Add(118, 0b11101110_00000000_00000000_00000000, 7); + data.Add(119, 0b11110000_00000000_00000000_00000000, 7); + data.Add(120, 0b11110010_00000000_00000000_00000000, 7); + data.Add(121, 0b11110100_00000000_00000000_00000000, 7); + data.Add(122, 0b11110110_00000000_00000000_00000000, 7); + data.Add(123, 0b11111111_11111100_00000000_00000000, 15); + data.Add(124, 0b11111111_10000000_00000000_00000000, 11); + data.Add(125, 0b11111111_11110100_00000000_00000000, 14); + data.Add(126, 0b11111111_11101000_00000000_00000000, 13); + data.Add(127, 0b11111111_11111111_11111111_11000000, 28); + data.Add(128, 0b11111111_11111110_01100000_00000000, 20); + data.Add(129, 0b11111111_11111111_01001000_00000000, 22); + data.Add(130, 0b11111111_11111110_01110000_00000000, 20); + data.Add(131, 0b11111111_11111110_10000000_00000000, 20); + data.Add(132, 0b11111111_11111111_01001100_00000000, 22); + data.Add(133, 0b11111111_11111111_01010000_00000000, 22); + data.Add(134, 0b11111111_11111111_01010100_00000000, 22); + data.Add(135, 0b11111111_11111111_10110010_00000000, 23); + data.Add(136, 0b11111111_11111111_01011000_00000000, 22); + data.Add(137, 0b11111111_11111111_10110100_00000000, 23); + data.Add(138, 0b11111111_11111111_10110110_00000000, 23); + data.Add(139, 0b11111111_11111111_10111000_00000000, 23); + data.Add(140, 0b11111111_11111111_10111010_00000000, 23); + data.Add(141, 0b11111111_11111111_10111100_00000000, 23); + data.Add(142, 0b11111111_11111111_11101011_00000000, 24); + data.Add(143, 0b11111111_11111111_10111110_00000000, 23); + data.Add(144, 0b11111111_11111111_11101100_00000000, 24); + data.Add(145, 0b11111111_11111111_11101101_00000000, 24); + data.Add(146, 0b11111111_11111111_01011100_00000000, 22); + data.Add(147, 0b11111111_11111111_11000000_00000000, 23); + data.Add(148, 0b11111111_11111111_11101110_00000000, 24); + data.Add(149, 0b11111111_11111111_11000010_00000000, 23); + data.Add(150, 0b11111111_11111111_11000100_00000000, 23); + data.Add(151, 0b11111111_11111111_11000110_00000000, 23); + data.Add(152, 0b11111111_11111111_11001000_00000000, 23); + data.Add(153, 0b11111111_11111110_11100000_00000000, 21); + data.Add(154, 0b11111111_11111111_01100000_00000000, 22); + data.Add(155, 0b11111111_11111111_11001010_00000000, 23); + data.Add(156, 0b11111111_11111111_01100100_00000000, 22); + data.Add(157, 0b11111111_11111111_11001100_00000000, 23); + data.Add(158, 0b11111111_11111111_11001110_00000000, 23); + data.Add(159, 0b11111111_11111111_11101111_00000000, 24); + data.Add(160, 0b11111111_11111111_01101000_00000000, 22); + data.Add(161, 0b11111111_11111110_11101000_00000000, 21); + data.Add(162, 0b11111111_11111110_10010000_00000000, 20); + data.Add(163, 0b11111111_11111111_01101100_00000000, 22); + data.Add(164, 0b11111111_11111111_01110000_00000000, 22); + data.Add(165, 0b11111111_11111111_11010000_00000000, 23); + data.Add(166, 0b11111111_11111111_11010010_00000000, 23); + data.Add(167, 0b11111111_11111110_11110000_00000000, 21); + data.Add(168, 0b11111111_11111111_11010100_00000000, 23); + data.Add(169, 0b11111111_11111111_01110100_00000000, 22); + data.Add(170, 0b11111111_11111111_01111000_00000000, 22); + data.Add(171, 0b11111111_11111111_11110000_00000000, 24); + data.Add(172, 0b11111111_11111110_11111000_00000000, 21); + data.Add(173, 0b11111111_11111111_01111100_00000000, 22); + data.Add(174, 0b11111111_11111111_11010110_00000000, 23); + data.Add(175, 0b11111111_11111111_11011000_00000000, 23); + data.Add(176, 0b11111111_11111111_00000000_00000000, 21); + data.Add(177, 0b11111111_11111111_00001000_00000000, 21); + data.Add(178, 0b11111111_11111111_10000000_00000000, 22); + data.Add(179, 0b11111111_11111111_00010000_00000000, 21); + data.Add(180, 0b11111111_11111111_11011010_00000000, 23); + data.Add(181, 0b11111111_11111111_10000100_00000000, 22); + data.Add(182, 0b11111111_11111111_11011100_00000000, 23); + data.Add(183, 0b11111111_11111111_11011110_00000000, 23); + data.Add(184, 0b11111111_11111110_10100000_00000000, 20); + data.Add(185, 0b11111111_11111111_10001000_00000000, 22); + data.Add(186, 0b11111111_11111111_10001100_00000000, 22); + data.Add(187, 0b11111111_11111111_10010000_00000000, 22); + data.Add(188, 0b11111111_11111111_11100000_00000000, 23); + data.Add(189, 0b11111111_11111111_10010100_00000000, 22); + data.Add(190, 0b11111111_11111111_10011000_00000000, 22); + data.Add(191, 0b11111111_11111111_11100010_00000000, 23); + data.Add(192, 0b11111111_11111111_11111000_00000000, 26); + data.Add(193, 0b11111111_11111111_11111000_01000000, 26); + data.Add(194, 0b11111111_11111110_10110000_00000000, 20); + data.Add(195, 0b11111111_11111110_00100000_00000000, 19); + data.Add(196, 0b11111111_11111111_10011100_00000000, 22); + data.Add(197, 0b11111111_11111111_11100100_00000000, 23); + data.Add(198, 0b11111111_11111111_10100000_00000000, 22); + data.Add(199, 0b11111111_11111111_11110110_00000000, 25); + data.Add(200, 0b11111111_11111111_11111000_10000000, 26); + data.Add(201, 0b11111111_11111111_11111000_11000000, 26); + data.Add(202, 0b11111111_11111111_11111001_00000000, 26); + data.Add(203, 0b11111111_11111111_11111011_11000000, 27); + data.Add(204, 0b11111111_11111111_11111011_11100000, 27); + data.Add(205, 0b11111111_11111111_11111001_01000000, 26); + data.Add(206, 0b11111111_11111111_11110001_00000000, 24); + data.Add(207, 0b11111111_11111111_11110110_10000000, 25); + data.Add(208, 0b11111111_11111110_01000000_00000000, 19); + data.Add(209, 0b11111111_11111111_00011000_00000000, 21); + data.Add(210, 0b11111111_11111111_11111001_10000000, 26); + data.Add(211, 0b11111111_11111111_11111100_00000000, 27); + data.Add(212, 0b11111111_11111111_11111100_00100000, 27); + data.Add(213, 0b11111111_11111111_11111001_11000000, 26); + data.Add(214, 0b11111111_11111111_11111100_01000000, 27); + data.Add(215, 0b11111111_11111111_11110010_00000000, 24); + data.Add(216, 0b11111111_11111111_00100000_00000000, 21); + data.Add(217, 0b11111111_11111111_00101000_00000000, 21); + data.Add(218, 0b11111111_11111111_11111010_00000000, 26); + data.Add(219, 0b11111111_11111111_11111010_01000000, 26); + data.Add(220, 0b11111111_11111111_11111111_11010000, 28); + data.Add(221, 0b11111111_11111111_11111100_01100000, 27); + data.Add(222, 0b11111111_11111111_11111100_10000000, 27); + data.Add(223, 0b11111111_11111111_11111100_10100000, 27); + data.Add(224, 0b11111111_11111110_11000000_00000000, 20); + data.Add(225, 0b11111111_11111111_11110011_00000000, 24); + data.Add(226, 0b11111111_11111110_11010000_00000000, 20); + data.Add(227, 0b11111111_11111111_00110000_00000000, 21); + data.Add(228, 0b11111111_11111111_10100100_00000000, 22); + data.Add(229, 0b11111111_11111111_00111000_00000000, 21); + data.Add(230, 0b11111111_11111111_01000000_00000000, 21); + data.Add(231, 0b11111111_11111111_11100110_00000000, 23); + data.Add(232, 0b11111111_11111111_10101000_00000000, 22); + data.Add(233, 0b11111111_11111111_10101100_00000000, 22); + data.Add(234, 0b11111111_11111111_11110111_00000000, 25); + data.Add(235, 0b11111111_11111111_11110111_10000000, 25); + data.Add(236, 0b11111111_11111111_11110100_00000000, 24); + data.Add(237, 0b11111111_11111111_11110101_00000000, 24); + data.Add(238, 0b11111111_11111111_11111010_10000000, 26); + data.Add(239, 0b11111111_11111111_11101000_00000000, 23); + data.Add(240, 0b11111111_11111111_11111010_11000000, 26); + data.Add(241, 0b11111111_11111111_11111100_11000000, 27); + data.Add(242, 0b11111111_11111111_11111011_00000000, 26); + data.Add(243, 0b11111111_11111111_11111011_01000000, 26); + data.Add(244, 0b11111111_11111111_11111100_11100000, 27); + data.Add(245, 0b11111111_11111111_11111101_00000000, 27); + data.Add(246, 0b11111111_11111111_11111101_00100000, 27); + data.Add(247, 0b11111111_11111111_11111101_01000000, 27); + data.Add(248, 0b11111111_11111111_11111101_01100000, 27); + data.Add(249, 0b11111111_11111111_11111111_11100000, 28); + data.Add(250, 0b11111111_11111111_11111101_10000000, 27); + data.Add(251, 0b11111111_11111111_11111101_10100000, 27); + data.Add(252, 0b11111111_11111111_11111101_11000000, 27); + data.Add(253, 0b11111111_11111111_11111101_11100000, 27); + data.Add(254, 0b11111111_11111111_11111110_00000000, 27); + data.Add(255, 0b11111111_11111111_11111011_10000000, 26); + data.Add(256, 0b11111111_11111111_11111111_11111100, 30); + + return data; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/IntegerDecoderTests.cs b/src/Servers/Kestrel/Core/test/IntegerDecoderTests.cs new file mode 100644 index 0000000000..b89b0f84a3 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/IntegerDecoderTests.cs @@ -0,0 +1,51 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class IntegerDecoderTests + { + [Theory] + [MemberData(nameof(IntegerData))] + public void IntegerDecode(int i, int prefixLength, byte[] octets) + { + var decoder = new IntegerDecoder(); + var result = decoder.BeginDecode(octets[0], prefixLength); + + if (octets.Length == 1) + { + Assert.True(result); + } + else + { + var j = 1; + + for (; j < octets.Length - 1; j++) + { + Assert.False(decoder.Decode(octets[j])); + } + + Assert.True(decoder.Decode(octets[j])); + } + + Assert.Equal(i, decoder.Value); + } + + public static TheoryData IntegerData + { + get + { + var data = new TheoryData(); + + data.Add(10, 5, new byte[] { 10 }); + data.Add(1337, 5, new byte[] { 0x1f, 0x9a, 0x0a }); + data.Add(42, 8, new byte[] { 42 }); + + return data; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/IntegerEncoderTests.cs b/src/Servers/Kestrel/Core/test/IntegerEncoderTests.cs new file mode 100644 index 0000000000..c667cc6cee --- /dev/null +++ b/src/Servers/Kestrel/Core/test/IntegerEncoderTests.cs @@ -0,0 +1,36 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class IntegerEncoderTests + { + [Theory] + [MemberData(nameof(IntegerData))] + public void IntegerEncode(int i, int prefixLength, byte[] expectedOctets) + { + var buffer = new byte[expectedOctets.Length]; + + Assert.True(IntegerEncoder.Encode(i, prefixLength, buffer, out var octets)); + Assert.Equal(expectedOctets.Length, octets); + Assert.Equal(expectedOctets, buffer); + } + + public static TheoryData IntegerData + { + get + { + var data = new TheoryData(); + + data.Add(10, 5, new byte[] { 10 }); + data.Add(1337, 5, new byte[] { 0x1f, 0x9a, 0x0a }); + data.Add(42, 8, new byte[] { 42 }); + + return data; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/KestrelEventSourceTests.cs b/src/Servers/Kestrel/Core/test/KestrelEventSourceTests.cs new file mode 100644 index 0000000000..7f407c92fb --- /dev/null +++ b/src/Servers/Kestrel/Core/test/KestrelEventSourceTests.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics.Tracing; +using System.Reflection; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class KestrelEventSourceTests + { + [Fact] + public void ExistsWithCorrectId() + { + var esType = typeof(KestrelServer).GetTypeInfo().Assembly.GetType( + "Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure.KestrelEventSource", + throwOnError: true, + ignoreCase: false + ); + + Assert.NotNull(esType); + + Assert.Equal("Microsoft-AspNetCore-Server-Kestrel", EventSource.GetName(esType)); + Assert.Equal(Guid.Parse("bdeb4676-a36e-5442-db99-4764e2326c7d"), EventSource.GetGuid(esType)); + Assert.NotEmpty(EventSource.GenerateManifest(esType, "assemblyPathToIncludeInManifest")); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/KestrelServerLimitsTests.cs b/src/Servers/Kestrel/Core/test/KestrelServerLimitsTests.cs new file mode 100644 index 0000000000..d71642f9dd --- /dev/null +++ b/src/Servers/Kestrel/Core/test/KestrelServerLimitsTests.cs @@ -0,0 +1,324 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class KestrelServerLimitsTests + { + [Fact] + public void MaxResponseBufferSizeDefault() + { + Assert.Equal(64 * 1024, (new KestrelServerLimits()).MaxResponseBufferSize); + } + + [Theory] + [InlineData((long)-1)] + [InlineData(long.MinValue)] + public void MaxResponseBufferSizeInvalid(long value) + { + Assert.Throws(() => + { + (new KestrelServerLimits()).MaxResponseBufferSize = value; + }); + } + + [Theory] + [InlineData(null)] + [InlineData((long)0)] + [InlineData((long)1)] + [InlineData(long.MaxValue)] + public void MaxResponseBufferSizeValid(long? value) + { + var o = new KestrelServerLimits(); + o.MaxResponseBufferSize = value; + Assert.Equal(value, o.MaxResponseBufferSize); + } + + [Fact] + public void MaxRequestBufferSizeDefault() + { + Assert.Equal(1024 * 1024, (new KestrelServerLimits()).MaxRequestBufferSize); + } + + [Theory] + [InlineData(-1)] + [InlineData(0)] + public void MaxRequestBufferSizeInvalid(int value) + { + Assert.Throws(() => + { + (new KestrelServerLimits()).MaxRequestBufferSize = value; + }); + } + + [Theory] + [InlineData(null)] + [InlineData(1)] + public void MaxRequestBufferSizeValid(int? value) + { + var o = new KestrelServerLimits(); + o.MaxRequestBufferSize = value; + Assert.Equal(value, o.MaxRequestBufferSize); + } + + [Fact] + public void MaxRequestLineSizeDefault() + { + Assert.Equal(8 * 1024, (new KestrelServerLimits()).MaxRequestLineSize); + } + + [Theory] + [InlineData(int.MinValue)] + [InlineData(-1)] + [InlineData(0)] + public void MaxRequestLineSizeInvalid(int value) + { + Assert.Throws(() => + { + (new KestrelServerLimits()).MaxRequestLineSize = value; + }); + } + + [Theory] + [InlineData(1)] + [InlineData(int.MaxValue)] + public void MaxRequestLineSizeValid(int value) + { + var o = new KestrelServerLimits(); + o.MaxRequestLineSize = value; + Assert.Equal(value, o.MaxRequestLineSize); + } + + [Fact] + public void MaxRequestHeadersTotalSizeDefault() + { + Assert.Equal(32 * 1024, (new KestrelServerLimits()).MaxRequestHeadersTotalSize); + } + + [Theory] + [InlineData(int.MinValue)] + [InlineData(-1)] + [InlineData(0)] + public void MaxRequestHeadersTotalSizeInvalid(int value) + { + var ex = Assert.Throws(() => new KestrelServerLimits().MaxRequestHeadersTotalSize = value); + Assert.StartsWith(CoreStrings.PositiveNumberRequired, ex.Message); + } + + [Theory] + [InlineData(1)] + [InlineData(int.MaxValue)] + public void MaxRequestHeadersTotalSizeValid(int value) + { + var o = new KestrelServerLimits(); + o.MaxRequestHeadersTotalSize = value; + Assert.Equal(value, o.MaxRequestHeadersTotalSize); + } + + [Fact] + public void MaxRequestHeaderCountDefault() + { + Assert.Equal(100, (new KestrelServerLimits()).MaxRequestHeaderCount); + } + + [Theory] + [InlineData(int.MinValue)] + [InlineData(-1)] + [InlineData(0)] + public void MaxRequestHeaderCountInvalid(int value) + { + Assert.Throws(() => + { + (new KestrelServerLimits()).MaxRequestHeaderCount = value; + }); + } + + [Theory] + [InlineData(1)] + [InlineData(int.MaxValue)] + public void MaxRequestHeaderCountValid(int value) + { + var o = new KestrelServerLimits(); + o.MaxRequestHeaderCount = value; + Assert.Equal(value, o.MaxRequestHeaderCount); + } + + [Fact] + public void KeepAliveTimeoutDefault() + { + Assert.Equal(TimeSpan.FromMinutes(2), new KestrelServerLimits().KeepAliveTimeout); + } + + [Theory] + [MemberData(nameof(TimeoutValidData))] + public void KeepAliveTimeoutValid(TimeSpan value) + { + Assert.Equal(value, new KestrelServerLimits { KeepAliveTimeout = value }.KeepAliveTimeout); + } + + [Fact] + public void KeepAliveTimeoutCanBeSetToInfinite() + { + Assert.Equal(TimeSpan.MaxValue, new KestrelServerLimits { KeepAliveTimeout = Timeout.InfiniteTimeSpan }.KeepAliveTimeout); + } + + [Theory] + [MemberData(nameof(TimeoutInvalidData))] + public void KeepAliveTimeoutInvalid(TimeSpan value) + { + var exception = Assert.Throws(() => new KestrelServerLimits { KeepAliveTimeout = value }); + + Assert.Equal("value", exception.ParamName); + Assert.StartsWith(CoreStrings.PositiveTimeSpanRequired, exception.Message); + } + + [Fact] + public void RequestHeadersTimeoutDefault() + { + Assert.Equal(TimeSpan.FromSeconds(30), new KestrelServerLimits().RequestHeadersTimeout); + } + + [Theory] + [MemberData(nameof(TimeoutValidData))] + public void RequestHeadersTimeoutValid(TimeSpan value) + { + Assert.Equal(value, new KestrelServerLimits { RequestHeadersTimeout = value }.RequestHeadersTimeout); + } + + [Fact] + public void RequestHeadersTimeoutCanBeSetToInfinite() + { + Assert.Equal(TimeSpan.MaxValue, new KestrelServerLimits { RequestHeadersTimeout = Timeout.InfiniteTimeSpan }.RequestHeadersTimeout); + } + + [Theory] + [MemberData(nameof(TimeoutInvalidData))] + public void RequestHeadersTimeoutInvalid(TimeSpan value) + { + var exception = Assert.Throws(() => new KestrelServerLimits { RequestHeadersTimeout = value }); + + Assert.Equal("value", exception.ParamName); + Assert.StartsWith(CoreStrings.PositiveTimeSpanRequired, exception.Message); + } + + [Fact] + public void MaxConnectionsDefault() + { + Assert.Null(new KestrelServerLimits().MaxConcurrentConnections); + Assert.Null(new KestrelServerLimits().MaxConcurrentUpgradedConnections); + } + + [Theory] + [InlineData(null)] + [InlineData(1u)] + [InlineData(long.MaxValue)] + public void MaxConnectionsValid(long? value) + { + var limits = new KestrelServerLimits + { + MaxConcurrentConnections = value + }; + + Assert.Equal(value, limits.MaxConcurrentConnections); + } + + [Theory] + [InlineData(long.MinValue)] + [InlineData(-1)] + [InlineData(0)] + public void MaxConnectionsInvalid(long value) + { + var ex = Assert.Throws(() => new KestrelServerLimits().MaxConcurrentConnections = value); + Assert.StartsWith(CoreStrings.PositiveNumberOrNullRequired, ex.Message); + } + + [Theory] + [InlineData(null)] + [InlineData(0)] + [InlineData(1)] + [InlineData(long.MaxValue)] + public void MaxUpgradedConnectionsValid(long? value) + { + var limits = new KestrelServerLimits + { + MaxConcurrentUpgradedConnections = value + }; + + Assert.Equal(value, limits.MaxConcurrentUpgradedConnections); + } + + + [Theory] + [InlineData(long.MinValue)] + [InlineData(-1)] + public void MaxUpgradedConnectionsInvalid(long value) + { + var ex = Assert.Throws(() => new KestrelServerLimits().MaxConcurrentUpgradedConnections = value); + Assert.StartsWith(CoreStrings.NonNegativeNumberOrNullRequired, ex.Message); + } + + [Fact] + public void MaxRequestBodySizeDefault() + { + // ~28.6 MB (https://www.iis.net/configreference/system.webserver/security/requestfiltering/requestlimits#005) + Assert.Equal(30000000, new KestrelServerLimits().MaxRequestBodySize); + } + + [Theory] + [InlineData(null)] + [InlineData(0)] + [InlineData(1)] + [InlineData(long.MaxValue)] + public void MaxRequestBodySizeValid(long? value) + { + var limits = new KestrelServerLimits + { + MaxRequestBodySize = value + }; + + Assert.Equal(value, limits.MaxRequestBodySize); + } + + [Theory] + [InlineData(long.MinValue)] + [InlineData(-1)] + public void MaxRequestBodySizeInvalid(long value) + { + var ex = Assert.Throws(() => new KestrelServerLimits().MaxRequestBodySize = value); + Assert.StartsWith(CoreStrings.NonNegativeNumberOrNullRequired, ex.Message); + } + + [Fact] + public void MinRequestBodyDataRateDefault() + { + Assert.NotNull(new KestrelServerLimits().MinRequestBodyDataRate); + Assert.Equal(240, new KestrelServerLimits().MinRequestBodyDataRate.BytesPerSecond); + Assert.Equal(TimeSpan.FromSeconds(5), new KestrelServerLimits().MinRequestBodyDataRate.GracePeriod); + } + + [Fact] + public void MinResponseBodyDataRateDefault() + { + Assert.NotNull(new KestrelServerLimits().MinResponseDataRate); + Assert.Equal(240, new KestrelServerLimits().MinResponseDataRate.BytesPerSecond); + Assert.Equal(TimeSpan.FromSeconds(5), new KestrelServerLimits().MinResponseDataRate.GracePeriod); + } + + public static TheoryData TimeoutValidData => new TheoryData + { + TimeSpan.FromTicks(1), + TimeSpan.MaxValue, + }; + + public static TheoryData TimeoutInvalidData => new TheoryData + { + TimeSpan.MinValue, + TimeSpan.FromTicks(-1), + TimeSpan.Zero + }; + } +} diff --git a/src/Servers/Kestrel/Core/test/KestrelServerOptionsTests.cs b/src/Servers/Kestrel/Core/test/KestrelServerOptionsTests.cs new file mode 100644 index 0000000000..75739a4c7b --- /dev/null +++ b/src/Servers/Kestrel/Core/test/KestrelServerOptionsTests.cs @@ -0,0 +1,68 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Net; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class KestrelServerOptionsTests + { + [Fact] + public void NoDelayDefaultsToTrue() + { + var o1 = new KestrelServerOptions(); + o1.Listen(IPAddress.Loopback, 0); + o1.Listen(IPAddress.Loopback, 0, d => + { + d.NoDelay = false; + }); + + Assert.True(o1.ListenOptions[0].NoDelay); + Assert.False(o1.ListenOptions[1].NoDelay); + } + + [Fact] + public void AllowSynchronousIODefaultsToTrue() + { + var options = new KestrelServerOptions(); + + Assert.True(options.AllowSynchronousIO); + } + + [Fact] + public void ConfigureEndpointDefaultsAppliesToNewEndpoints() + { + var options = new KestrelServerOptions(); + options.ListenLocalhost(5000); + + Assert.True(options.ListenOptions[0].NoDelay); + + options.ConfigureEndpointDefaults(opt => + { + opt.NoDelay = false; + }); + + options.Listen(new IPEndPoint(IPAddress.Loopback, 5000), opt => + { + // ConfigureEndpointDefaults runs before this callback + Assert.False(opt.NoDelay); + }); + Assert.False(options.ListenOptions[1].NoDelay); + + options.ListenLocalhost(5000, opt => + { + Assert.False(opt.NoDelay); + opt.NoDelay = true; // Can be overriden + }); + Assert.True(options.ListenOptions[2].NoDelay); + + + options.ListenAnyIP(5000, opt => + { + Assert.False(opt.NoDelay); + }); + Assert.False(options.ListenOptions[3].NoDelay); + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/test/KestrelServerTests.cs b/src/Servers/Kestrel/Core/test/KestrelServerTests.cs new file mode 100644 index 0000000000..c5fbd18402 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/KestrelServerTests.cs @@ -0,0 +1,353 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Hosting.Server.Features; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class KestrelServerTests + { + private KestrelServerOptions CreateServerOptions() + { + var serverOptions = new KestrelServerOptions(); + serverOptions.ApplicationServices = new ServiceCollection() + .AddLogging() + .BuildServiceProvider(); + return serverOptions; + } + + [Fact] + public void StartWithInvalidAddressThrows() + { + var testLogger = new TestApplicationErrorLogger { ThrowOnCriticalErrors = false }; + + using (var server = CreateServer(CreateServerOptions(), testLogger)) + { + server.Features.Get().Addresses.Add("http:/asdf"); + + var exception = Assert.Throws(() => StartDummyApplication(server)); + + Assert.Contains("Invalid URL", exception.Message); + Assert.Equal(1, testLogger.CriticalErrorsLogged); + } + } + + [Fact] + public void StartWithHttpsAddressConfiguresHttpsEndpoints() + { + var options = CreateServerOptions(); + options.DefaultCertificate = TestResources.GetTestCertificate(); + using (var server = CreateServer(options)) + { + server.Features.Get().Addresses.Add("https://127.0.0.1:0"); + + StartDummyApplication(server); + + Assert.True(server.Options.ListenOptions.Any()); + Assert.Contains(server.Options.ListenOptions[0].ConnectionAdapters, adapter => adapter.IsHttps); + } + } + + [Fact] + public void KestrelServerThrowsUsefulExceptionIfDefaultHttpsProviderNotAdded() + { + var options = CreateServerOptions(); + options.IsDevCertLoaded = true; // Prevent the system default from being loaded + using (var server = CreateServer(options, throwOnCriticalErrors: false)) + { + server.Features.Get().Addresses.Add("https://127.0.0.1:0"); + + var ex = Assert.Throws(() => StartDummyApplication(server)); + Assert.Equal(CoreStrings.NoCertSpecifiedNoDevelopmentCertificateFound, ex.Message); + } + } + + [Fact] + public void KestrelServerDoesNotThrowIfNoDefaultHttpsProviderButNoHttpUrls() + { + using (var server = CreateServer(CreateServerOptions())) + { + server.Features.Get().Addresses.Add("http://127.0.0.1:0"); + + StartDummyApplication(server); + } + } + + [Fact] + public void KestrelServerDoesNotThrowIfNoDefaultHttpsProviderButManualListenOptions() + { + var serverOptions = CreateServerOptions(); + serverOptions.Listen(new IPEndPoint(IPAddress.Loopback, 0)); + + using (var server = CreateServer(serverOptions)) + { + server.Features.Get().Addresses.Add("https://127.0.0.1:0"); + + StartDummyApplication(server); + } + } + + [Fact] + public void StartWithPathBaseInAddressThrows() + { + var testLogger = new TestApplicationErrorLogger { ThrowOnCriticalErrors = false }; + + using (var server = CreateServer(new KestrelServerOptions(), testLogger)) + { + server.Features.Get().Addresses.Add("http://127.0.0.1:0/base"); + + var exception = Assert.Throws(() => StartDummyApplication(server)); + + Assert.Equal( + $"A path base can only be configured using {nameof(IApplicationBuilder)}.UsePathBase().", + exception.Message); + Assert.Equal(1, testLogger.CriticalErrorsLogged); + } + } + + [Theory] + [InlineData("http://localhost:5000")] + [InlineData("The value of the string shouldn't matter.")] + [InlineData(null)] + public void StartWarnsWhenIgnoringIServerAddressesFeature(string ignoredAddress) + { + var testLogger = new TestApplicationErrorLogger(); + var kestrelOptions = new KestrelServerOptions(); + + // Directly configuring an endpoint using Listen causes the IServerAddressesFeature to be ignored. + kestrelOptions.Listen(IPAddress.Loopback, 0); + + using (var server = CreateServer(kestrelOptions, testLogger)) + { + server.Features.Get().Addresses.Add(ignoredAddress); + StartDummyApplication(server); + + var warning = testLogger.Messages.Single(log => log.LogLevel == LogLevel.Warning); + Assert.Contains("Overriding", warning.Message); + } + } + + [Theory] + [InlineData(1, 2)] + [InlineData(int.MaxValue - 1, int.MaxValue)] + public void StartWithMaxRequestBufferSizeLessThanMaxRequestLineSizeThrows(long maxRequestBufferSize, int maxRequestLineSize) + { + var testLogger = new TestApplicationErrorLogger { ThrowOnCriticalErrors = false }; + var options = new KestrelServerOptions + { + Limits = + { + MaxRequestBufferSize = maxRequestBufferSize, + MaxRequestLineSize = maxRequestLineSize + } + }; + + using (var server = CreateServer(options, testLogger)) + { + var exception = Assert.Throws(() => StartDummyApplication(server)); + + Assert.Equal( + CoreStrings.FormatMaxRequestBufferSmallerThanRequestLineBuffer(maxRequestBufferSize, maxRequestLineSize), + exception.Message); + Assert.Equal(1, testLogger.CriticalErrorsLogged); + } + } + + [Theory] + [InlineData(1, 2)] + [InlineData(int.MaxValue - 1, int.MaxValue)] + public void StartWithMaxRequestBufferSizeLessThanMaxRequestHeadersTotalSizeThrows(long maxRequestBufferSize, int maxRequestHeadersTotalSize) + { + var testLogger = new TestApplicationErrorLogger { ThrowOnCriticalErrors = false }; + var options = new KestrelServerOptions + { + Limits = + { + MaxRequestBufferSize = maxRequestBufferSize, + MaxRequestLineSize = (int)maxRequestBufferSize, + MaxRequestHeadersTotalSize = maxRequestHeadersTotalSize + } + }; + + using (var server = CreateServer(options, testLogger)) + { + var exception = Assert.Throws(() => StartDummyApplication(server)); + + Assert.Equal( + CoreStrings.FormatMaxRequestBufferSmallerThanRequestHeaderBuffer(maxRequestBufferSize, maxRequestHeadersTotalSize), + exception.Message); + Assert.Equal(1, testLogger.CriticalErrorsLogged); + } + } + + [Fact] + public void LoggerCategoryNameIsKestrelServerNamespace() + { + var mockLoggerFactory = new Mock(); + var mockLogger = new Mock(); + mockLoggerFactory.Setup(m => m.CreateLogger(It.IsAny())).Returns(mockLogger.Object); + new KestrelServer(Options.Create(null), Mock.Of(), mockLoggerFactory.Object); + mockLoggerFactory.Verify(factory => factory.CreateLogger("Microsoft.AspNetCore.Server.Kestrel")); + } + + [Fact] + public void StartWithNoTransportFactoryThrows() + { + var mockLoggerFactory = new Mock(); + var mockLogger = new Mock(); + mockLoggerFactory.Setup(m => m.CreateLogger(It.IsAny())).Returns(mockLogger.Object); + var exception = Assert.Throws(() => + new KestrelServer(Options.Create(null), null, mockLoggerFactory.Object)); + + Assert.Equal("transportFactory", exception.ParamName); + } + + [Fact] + public async Task StopAsyncCallsCompleteWhenFirstCallCompletes() + { + var options = new KestrelServerOptions + { + ListenOptions = + { + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + } + }; + + var unbind = new SemaphoreSlim(0); + var stop = new SemaphoreSlim(0); + + var mockTransport = new Mock(); + mockTransport + .Setup(transport => transport.BindAsync()) + .Returns(Task.CompletedTask); + mockTransport + .Setup(transport => transport.UnbindAsync()) + .Returns(async () => await unbind.WaitAsync()); + mockTransport + .Setup(transport => transport.StopAsync()) + .Returns(async () => await stop.WaitAsync()); + + var mockTransportFactory = new Mock(); + mockTransportFactory + .Setup(transportFactory => transportFactory.Create(It.IsAny(), It.IsAny())) + .Returns(mockTransport.Object); + + var mockLoggerFactory = new Mock(); + var mockLogger = new Mock(); + mockLoggerFactory.Setup(m => m.CreateLogger(It.IsAny())).Returns(mockLogger.Object); + var server = new KestrelServer(Options.Create(options), mockTransportFactory.Object, mockLoggerFactory.Object); + await server.StartAsync(new DummyApplication(), CancellationToken.None); + + var stopTask1 = server.StopAsync(default(CancellationToken)); + var stopTask2 = server.StopAsync(default(CancellationToken)); + var stopTask3 = server.StopAsync(default(CancellationToken)); + + Assert.False(stopTask1.IsCompleted); + Assert.False(stopTask2.IsCompleted); + Assert.False(stopTask3.IsCompleted); + + unbind.Release(); + stop.Release(); + + await Task.WhenAll(new[] { stopTask1, stopTask2, stopTask3 }).DefaultTimeout(); + + mockTransport.Verify(transport => transport.UnbindAsync(), Times.Once); + mockTransport.Verify(transport => transport.StopAsync(), Times.Once); + } + + [Fact] + public async Task StopAsyncCallsCompleteWithThrownException() + { + var options = new KestrelServerOptions + { + ListenOptions = + { + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + } + }; + + var unbind = new SemaphoreSlim(0); + var unbindException = new InvalidOperationException(); + + var mockTransport = new Mock(); + mockTransport + .Setup(transport => transport.BindAsync()) + .Returns(Task.CompletedTask); + mockTransport + .Setup(transport => transport.UnbindAsync()) + .Returns(async () => + { + await unbind.WaitAsync(); + throw unbindException; + }); + mockTransport + .Setup(transport => transport.StopAsync()) + .Returns(Task.CompletedTask); + + var mockTransportFactory = new Mock(); + mockTransportFactory + .Setup(transportFactory => transportFactory.Create(It.IsAny(), It.IsAny())) + .Returns(mockTransport.Object); + + var mockLoggerFactory = new Mock(); + var mockLogger = new Mock(); + mockLoggerFactory.Setup(m => m.CreateLogger(It.IsAny())).Returns(mockLogger.Object); + var server = new KestrelServer(Options.Create(options), mockTransportFactory.Object, mockLoggerFactory.Object); + await server.StartAsync(new DummyApplication(), CancellationToken.None); + + var stopTask1 = server.StopAsync(default(CancellationToken)); + var stopTask2 = server.StopAsync(default(CancellationToken)); + var stopTask3 = server.StopAsync(default(CancellationToken)); + + Assert.False(stopTask1.IsCompleted); + Assert.False(stopTask2.IsCompleted); + Assert.False(stopTask3.IsCompleted); + + unbind.Release(); + + var timeout = TestConstants.DefaultTimeout; + Assert.Same(unbindException, await Assert.ThrowsAsync(() => stopTask1.TimeoutAfter(timeout))); + Assert.Same(unbindException, await Assert.ThrowsAsync(() => stopTask2.TimeoutAfter(timeout))); + Assert.Same(unbindException, await Assert.ThrowsAsync(() => stopTask3.TimeoutAfter(timeout))); + + mockTransport.Verify(transport => transport.UnbindAsync(), Times.Once); + } + + private static KestrelServer CreateServer(KestrelServerOptions options, ILogger testLogger) + { + return new KestrelServer(Options.Create(options), new MockTransportFactory(), new LoggerFactory(new[] { new KestrelTestLoggerProvider(testLogger) })); + } + + private static KestrelServer CreateServer(KestrelServerOptions options, bool throwOnCriticalErrors = true) + { + return new KestrelServer(Options.Create(options), new MockTransportFactory(), new LoggerFactory(new[] { new KestrelTestLoggerProvider(throwOnCriticalErrors) })); + } + + private static void StartDummyApplication(IServer server) + { + server.StartAsync(new DummyApplication(context => Task.CompletedTask), CancellationToken.None).GetAwaiter().GetResult(); + } + + private class MockTransportFactory : ITransportFactory + { + public ITransport Create(IEndPointInformation endPointInformation, IConnectionDispatcher handler) + { + return Mock.Of(); + } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/KnownStringsTests.cs b/src/Servers/Kestrel/Core/test/KnownStringsTests.cs new file mode 100644 index 0000000000..75565d04f4 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/KnownStringsTests.cs @@ -0,0 +1,86 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Xunit; + +namespace Microsoft.AspNetCore.Server.KestrelTests +{ + public class KnownStringsTests + { + static byte[] _methodConnect = Encoding.ASCII.GetBytes("CONNECT "); + static byte[] _methodDelete = Encoding.ASCII.GetBytes("DELETE \0"); + static byte[] _methodGet = Encoding.ASCII.GetBytes("GET "); + static byte[] _methodHead = Encoding.ASCII.GetBytes("HEAD \0\0\0"); + static byte[] _methodPatch = Encoding.ASCII.GetBytes("PATCH \0\0"); + static byte[] _methodPost = Encoding.ASCII.GetBytes("POST \0\0\0"); + static byte[] _methodPut = Encoding.ASCII.GetBytes("PUT \0\0\0\0"); + static byte[] _methodOptions = Encoding.ASCII.GetBytes("OPTIONS "); + static byte[] _methodTrace = Encoding.ASCII.GetBytes("TRACE \0\0"); + + const int MagicNumer = 0x0600000C; + static byte[] _invalidMethod1 = BitConverter.GetBytes((ulong)MagicNumer); + static byte[] _invalidMethod2 = { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }; + static byte[] _invalidMethod3 = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; + static byte[] _invalidMethod4 = Encoding.ASCII.GetBytes("CONNECT_"); + static byte[] _invalidMethod5 = Encoding.ASCII.GetBytes("DELETE_\0"); + static byte[] _invalidMethod6 = Encoding.ASCII.GetBytes("GET_"); + static byte[] _invalidMethod7 = Encoding.ASCII.GetBytes("HEAD_\0\0\0"); + static byte[] _invalidMethod8 = Encoding.ASCII.GetBytes("PATCH_\0\0"); + static byte[] _invalidMethod9 = Encoding.ASCII.GetBytes("POST_\0\0\0"); + static byte[] _invalidMethod10 = Encoding.ASCII.GetBytes("PUT_\0\0\0\0"); + static byte[] _invalidMethod11 = Encoding.ASCII.GetBytes("OPTIONS_"); + static byte[] _invalidMethod12 = Encoding.ASCII.GetBytes("TRACE_\0\0"); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static object[] CreateTestDataEntry(byte[] methodData, HttpMethod expectedMethod, int expectedLength, bool expectedResult) + { + return new object[] { methodData, expectedMethod, expectedLength, expectedResult }; + } + + private static readonly object[][] _testData = new object[][] + { + CreateTestDataEntry(_methodGet, HttpMethod.Get, 3, true), + CreateTestDataEntry(_methodPut, HttpMethod.Put, 3, true), + CreateTestDataEntry(_methodPost, HttpMethod.Post, 4, true), + CreateTestDataEntry(_methodHead, HttpMethod.Head, 4, true), + CreateTestDataEntry(_methodTrace, HttpMethod.Trace, 5, true), + CreateTestDataEntry(_methodPatch, HttpMethod.Patch, 5, true), + CreateTestDataEntry(_methodDelete, HttpMethod.Delete, 6, true), + CreateTestDataEntry(_methodConnect, HttpMethod.Connect, 7, true), + CreateTestDataEntry(_methodOptions, HttpMethod.Options, 7, true), + CreateTestDataEntry(_invalidMethod1, HttpMethod.Custom, 0, false), + CreateTestDataEntry(_invalidMethod2, HttpMethod.Custom, 0, false), + CreateTestDataEntry(_invalidMethod3, HttpMethod.Custom, 0, false), + CreateTestDataEntry(_invalidMethod4, HttpMethod.Custom, 0, false), + CreateTestDataEntry(_invalidMethod5, HttpMethod.Custom, 0, false), + CreateTestDataEntry(_invalidMethod6, HttpMethod.Custom, 0, false), + CreateTestDataEntry(_invalidMethod7, HttpMethod.Custom, 0, false), + CreateTestDataEntry(_invalidMethod8, HttpMethod.Custom, 0, false), + CreateTestDataEntry(_invalidMethod9, HttpMethod.Custom, 0, false), + CreateTestDataEntry(_invalidMethod10, HttpMethod.Custom, 0, false), + CreateTestDataEntry(_invalidMethod11, HttpMethod.Custom, 0, false), + CreateTestDataEntry(_invalidMethod12, HttpMethod.Custom, 0, false), + }; + + public static IEnumerable TestData => _testData; + + [Theory] + [MemberData(nameof(TestData), MemberType = typeof(KnownStringsTests))] + public void GetsKnownMethod(byte[] methodData, HttpMethod expectedMethod, int expectedLength, bool expectedResult) + { + var data = new Span(methodData); + + var result = data.GetKnownMethod(out var method, out var length); + + Assert.Equal(expectedResult, result); + Assert.Equal(expectedMethod, method); + Assert.Equal(expectedLength, length); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/ListenOptionsTests.cs b/src/Servers/Kestrel/Core/test/ListenOptionsTests.cs new file mode 100644 index 0000000000..77e5a33bbf --- /dev/null +++ b/src/Servers/Kestrel/Core/test/ListenOptionsTests.cs @@ -0,0 +1,53 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class ListenOptionsTests + { + [Fact] + public void ProtocolsDefault() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + Assert.Equal(HttpProtocols.Http1, listenOptions.Protocols); + } + + [Fact] + public void LocalHostListenOptionsClonesConnectionMiddleware() + { + var localhostListenOptions = new LocalhostListenOptions(1004); + localhostListenOptions.ConnectionAdapters.Add(new PassThroughConnectionAdapter()); + var serviceProvider = new ServiceCollection().BuildServiceProvider(); + localhostListenOptions.KestrelServerOptions = new KestrelServerOptions() + { + ApplicationServices = serviceProvider + }; + var middlewareRan = false; + localhostListenOptions.Use(next => + { + middlewareRan = true; + return context => Task.CompletedTask; + }); + + var clone = localhostListenOptions.Clone(IPAddress.IPv6Loopback); + var app = clone.Build(); + + // Execute the delegate + app(null); + + Assert.True(middlewareRan); + Assert.NotNull(clone.KestrelServerOptions); + Assert.NotNull(serviceProvider); + Assert.Same(serviceProvider, clone.ApplicationServices); + Assert.Single(clone.ConnectionAdapters); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/MessageBodyTests.cs b/src/Servers/Kestrel/Core/test/MessageBodyTests.cs new file mode 100644 index 0000000000..5f006b12e3 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/MessageBodyTests.cs @@ -0,0 +1,894 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.IO.Pipelines; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Moq; +using Xunit; +using Xunit.Sdk; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class MessageBodyTests + { + [Theory] + [InlineData(HttpVersion.Http10)] + [InlineData(HttpVersion.Http11)] + public async Task CanReadFromContentLength(HttpVersion httpVersion) + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var mockBodyControl = new Mock(); + mockBodyControl.Setup(m => m.AllowSynchronousIO).Returns(true); + var stream = new HttpRequestStream(mockBodyControl.Object); + stream.StartAcceptingReads(body); + + input.Add("Hello"); + + var buffer = new byte[1024]; + + var count = stream.Read(buffer, 0, buffer.Length); + Assert.Equal(5, count); + AssertASCII("Hello", new ArraySegment(buffer, 0, count)); + + count = stream.Read(buffer, 0, buffer.Length); + Assert.Equal(0, count); + + await body.StopAsync(); + } + } + + [Theory] + [InlineData(HttpVersion.Http10)] + [InlineData(HttpVersion.Http11)] + public async Task CanReadAsyncFromContentLength(HttpVersion httpVersion) + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + input.Add("Hello"); + + var buffer = new byte[1024]; + + var count = await stream.ReadAsync(buffer, 0, buffer.Length); + Assert.Equal(5, count); + AssertASCII("Hello", new ArraySegment(buffer, 0, count)); + + count = await stream.ReadAsync(buffer, 0, buffer.Length); + Assert.Equal(0, count); + + await body.StopAsync(); + } + } + + [Fact] + public async Task CanReadFromChunkedEncoding() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); + var mockBodyControl = new Mock(); + mockBodyControl.Setup(m => m.AllowSynchronousIO).Returns(true); + var stream = new HttpRequestStream(mockBodyControl.Object); + stream.StartAcceptingReads(body); + + input.Add("5\r\nHello\r\n"); + + var buffer = new byte[1024]; + + var count = stream.Read(buffer, 0, buffer.Length); + Assert.Equal(5, count); + AssertASCII("Hello", new ArraySegment(buffer, 0, count)); + + input.Add("0\r\n\r\n"); + + count = stream.Read(buffer, 0, buffer.Length); + Assert.Equal(0, count); + + await body.StopAsync(); + } + } + + [Fact] + public async Task CanReadAsyncFromChunkedEncoding() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + input.Add("5\r\nHello\r\n"); + + var buffer = new byte[1024]; + + var count = await stream.ReadAsync(buffer, 0, buffer.Length); + Assert.Equal(5, count); + AssertASCII("Hello", new ArraySegment(buffer, 0, count)); + + input.Add("0\r\n\r\n"); + + count = await stream.ReadAsync(buffer, 0, buffer.Length); + Assert.Equal(0, count); + + await body.StopAsync(); + } + } + + [Fact] + public async Task ReadExitsGivenIncompleteChunkedExtension() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + input.Add("5;\r\0"); + + var buffer = new byte[1024]; + var readTask = stream.ReadAsync(buffer, 0, buffer.Length); + + Assert.False(readTask.IsCompleted); + + input.Add("\r\r\r\nHello\r\n0\r\n\r\n"); + + Assert.Equal(5, await readTask.DefaultTimeout()); + Assert.Equal(0, await stream.ReadAsync(buffer, 0, buffer.Length)); + + await body.StopAsync(); + } + } + + [Fact] + public async Task ReadThrowsGivenChunkPrefixGreaterThanMaxInt() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + input.Add("80000000\r\n"); + + var buffer = new byte[1024]; + var ex = await Assert.ThrowsAsync(async () => + await stream.ReadAsync(buffer, 0, buffer.Length)); + Assert.IsType(ex.InnerException); + Assert.Equal(CoreStrings.BadRequest_BadChunkSizeData, ex.Message); + + await body.StopAsync(); + } + } + + [Fact] + public async Task ReadThrowsGivenChunkPrefixGreaterThan8Bytes() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + input.Add("012345678\r"); + + var buffer = new byte[1024]; + var ex = await Assert.ThrowsAsync(async () => + await stream.ReadAsync(buffer, 0, buffer.Length)); + + Assert.Equal(CoreStrings.BadRequest_BadChunkSizeData, ex.Message); + + await body.StopAsync(); + } + } + + [Theory] + [InlineData(HttpVersion.Http10)] + [InlineData(HttpVersion.Http11)] + public async Task CanReadFromRemainingData(HttpVersion httpVersion) + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders { HeaderConnection = "upgrade" }, input.Http1Connection); + var mockBodyControl = new Mock(); + mockBodyControl.Setup(m => m.AllowSynchronousIO).Returns(true); + var stream = new HttpRequestStream(mockBodyControl.Object); + stream.StartAcceptingReads(body); + + input.Add("Hello"); + + var buffer = new byte[1024]; + + var count = stream.Read(buffer, 0, buffer.Length); + Assert.Equal(5, count); + AssertASCII("Hello", new ArraySegment(buffer, 0, count)); + + input.Fin(); + + await body.StopAsync(); + } + } + + [Theory] + [InlineData(HttpVersion.Http10)] + [InlineData(HttpVersion.Http11)] + public async Task CanReadAsyncFromRemainingData(HttpVersion httpVersion) + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders { HeaderConnection = "upgrade" }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + input.Add("Hello"); + + var buffer = new byte[1024]; + + var count = await stream.ReadAsync(buffer, 0, buffer.Length); + Assert.Equal(5, count); + AssertASCII("Hello", new ArraySegment(buffer, 0, count)); + + input.Fin(); + + await body.StopAsync(); + } + } + + [Theory] + [InlineData(HttpVersion.Http10)] + [InlineData(HttpVersion.Http11)] + public async Task ReadFromNoContentLengthReturnsZero(HttpVersion httpVersion) + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders(), input.Http1Connection); + var mockBodyControl = new Mock(); + mockBodyControl.Setup(m => m.AllowSynchronousIO).Returns(true); + var stream = new HttpRequestStream(mockBodyControl.Object); + stream.StartAcceptingReads(body); + + input.Add("Hello"); + + var buffer = new byte[1024]; + Assert.Equal(0, stream.Read(buffer, 0, buffer.Length)); + + await body.StopAsync(); + } + } + + [Theory] + [InlineData(HttpVersion.Http10)] + [InlineData(HttpVersion.Http11)] + public async Task ReadAsyncFromNoContentLengthReturnsZero(HttpVersion httpVersion) + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(httpVersion, new HttpRequestHeaders(), input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + input.Add("Hello"); + + var buffer = new byte[1024]; + Assert.Equal(0, await stream.ReadAsync(buffer, 0, buffer.Length)); + + await body.StopAsync(); + } + } + + [Fact] + public async Task CanHandleLargeBlocks() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http10, new HttpRequestHeaders { HeaderContentLength = "8197" }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + // Input needs to be greater than 4032 bytes to allocate a block not backed by a slab. + var largeInput = new string('a', 8192); + + input.Add(largeInput); + // Add a smaller block to the end so that SocketInput attempts to return the large + // block to the memory pool. + input.Add("Hello"); + + var ms = new MemoryStream(); + + await stream.CopyToAsync(ms); + var requestArray = ms.ToArray(); + Assert.Equal(8197, requestArray.Length); + AssertASCII(largeInput + "Hello", new ArraySegment(requestArray, 0, requestArray.Length)); + + await body.StopAsync(); + } + } + + [Fact] + public void ForThrowsWhenFinalTransferCodingIsNotChunked() + { + using (var input = new TestInput()) + { + var ex = Assert.Throws(() => + Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked, not-chunked" }, input.Http1Connection)); + + Assert.Equal(StatusCodes.Status400BadRequest, ex.StatusCode); + Assert.Equal(CoreStrings.FormatBadRequest_FinalTransferCodingNotChunked("chunked, not-chunked"), ex.Message); + } + } + + [Theory] + [InlineData(HttpMethod.Post)] + [InlineData(HttpMethod.Put)] + public void ForThrowsWhenMethodRequiresLengthButNoContentLengthOrTransferEncodingIsSet(HttpMethod method) + { + using (var input = new TestInput()) + { + input.Http1Connection.Method = method; + var ex = Assert.Throws(() => + Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders(), input.Http1Connection)); + + Assert.Equal(StatusCodes.Status411LengthRequired, ex.StatusCode); + Assert.Equal(CoreStrings.FormatBadRequest_LengthRequired(((IHttpRequestFeature)input.Http1Connection).Method), ex.Message); + } + } + + [Theory] + [InlineData(HttpMethod.Post)] + [InlineData(HttpMethod.Put)] + public void ForThrowsWhenMethodRequiresLengthButNoContentLengthSetHttp10(HttpMethod method) + { + using (var input = new TestInput()) + { + input.Http1Connection.Method = method; + var ex = Assert.Throws(() => + Http1MessageBody.For(HttpVersion.Http10, new HttpRequestHeaders(), input.Http1Connection)); + + Assert.Equal(StatusCodes.Status400BadRequest, ex.StatusCode); + Assert.Equal(CoreStrings.FormatBadRequest_LengthRequiredHttp10(((IHttpRequestFeature)input.Http1Connection).Method), ex.Message); + } + } + + [Fact] + public async Task CopyToAsyncDoesNotCompletePipeReader() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http10, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + + input.Add("Hello"); + + using (var ms = new MemoryStream()) + { + await body.CopyToAsync(ms); + } + + Assert.Equal(0, await body.ReadAsync(new ArraySegment(new byte[1]))); + + await body.StopAsync(); + } + } + + [Fact] + public async Task ConsumeAsyncConsumesAllRemainingInput() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http10, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + + input.Add("Hello"); + + await body.ConsumeAsync(); + + Assert.Equal(0, await body.ReadAsync(new ArraySegment(new byte[1]))); + + await body.StopAsync(); + } + } + + [Fact] + public async Task CopyToAsyncDoesNotCopyBlocks() + { + var writeCount = 0; + var writeTcs = new TaskCompletionSource<(byte[], int, int)>(); + var mockDestination = new Mock() { CallBase = true }; + + mockDestination + .Setup(m => m.WriteAsync(It.IsAny(), It.IsAny(), It.IsAny(), CancellationToken.None)) + .Callback((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => + { + writeTcs.SetResult((buffer, offset, count)); + writeCount++; + }) + .Returns(Task.CompletedTask); + + using (var memoryPool = KestrelMemoryPool.Create()) + { + var options = new PipeOptions(pool: memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + var transport = pair.Transport; + var application = pair.Application; + var http1ConnectionContext = new Http1ConnectionContext + { + ServiceContext = new TestServiceContext(), + ConnectionFeatures = new FeatureCollection(), + Application = application, + Transport = transport, + MemoryPool = memoryPool, + TimeoutControl = Mock.Of() + }; + var http1Connection = new Http1Connection(http1ConnectionContext) + { + HasStartedConsumingRequestBody = true + }; + + var headers = new HttpRequestHeaders { HeaderContentLength = "12" }; + var body = Http1MessageBody.For(HttpVersion.Http11, headers, http1Connection); + + var copyToAsyncTask = body.CopyToAsync(mockDestination.Object); + + var bytes = Encoding.ASCII.GetBytes("Hello "); + var buffer = http1Connection.RequestBodyPipe.Writer.GetMemory(2048); + ArraySegment segment; + Assert.True(MemoryMarshal.TryGetArray(buffer, out segment)); + Buffer.BlockCopy(bytes, 0, segment.Array, segment.Offset, bytes.Length); + http1Connection.RequestBodyPipe.Writer.Advance(bytes.Length); + await http1Connection.RequestBodyPipe.Writer.FlushAsync(); + + // Verify the block passed to Stream.WriteAsync() is the same one incoming data was written into. + Assert.Equal((segment.Array, segment.Offset, bytes.Length), await writeTcs.Task); + + // Verify the again when GetMemory returns the tail space of the same block. + writeTcs = new TaskCompletionSource<(byte[], int, int)>(); + bytes = Encoding.ASCII.GetBytes("World!"); + buffer = http1Connection.RequestBodyPipe.Writer.GetMemory(2048); + Assert.True(MemoryMarshal.TryGetArray(buffer, out segment)); + Buffer.BlockCopy(bytes, 0, segment.Array, segment.Offset, bytes.Length); + http1Connection.RequestBodyPipe.Writer.Advance(bytes.Length); + await http1Connection.RequestBodyPipe.Writer.FlushAsync(); + + Assert.Equal((segment.Array, segment.Offset, bytes.Length), await writeTcs.Task); + + http1Connection.RequestBodyPipe.Writer.Complete(); + + await copyToAsyncTask; + + Assert.Equal(2, writeCount); + + // Don't call body.StopAsync() because PumpAsync() was never called. + http1Connection.RequestBodyPipe.Reader.Complete(); + } + } + + [Theory] + [InlineData("keep-alive, upgrade")] + [InlineData("Keep-Alive, Upgrade")] + [InlineData("upgrade, keep-alive")] + [InlineData("Upgrade, Keep-Alive")] + public async Task ConnectionUpgradeKeepAlive(string headerConnection) + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderConnection = headerConnection }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + input.Add("Hello"); + + var buffer = new byte[1024]; + Assert.Equal(5, await stream.ReadAsync(buffer, 0, buffer.Length)); + AssertASCII("Hello", new ArraySegment(buffer, 0, 5)); + + input.Fin(); + + await body.StopAsync(); + } + } + + [Fact] + public async Task UpgradeConnectionAcceptsContentLengthZero() + { + // https://tools.ietf.org/html/rfc7230#section-3.3.2 + // "A user agent SHOULD NOT send a Content-Length header field when the request message does not contain + // a payload body and the method semantics do not anticipate such a body." + // ==> it can actually send that header + var headerConnection = "Upgrade, Keep-Alive"; + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderConnection = headerConnection, ContentLength = 0 }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + input.Add("Hello"); + + var buffer = new byte[1024]; + Assert.Equal(5, await stream.ReadAsync(buffer, 0, buffer.Length)); + AssertASCII("Hello", new ArraySegment(buffer, 0, 5)); + + input.Fin(); + + await body.StopAsync(); + } + } + + [Fact] + public async Task PumpAsyncDoesNotReturnAfterCancelingInput() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "2" }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + // Add some input and consume it to ensure PumpAsync is running + input.Add("a"); + Assert.Equal(1, await stream.ReadAsync(new byte[1], 0, 1)); + + input.Transport.Input.CancelPendingRead(); + + // Add more input and verify is read + input.Add("b"); + Assert.Equal(1, await stream.ReadAsync(new byte[1], 0, 1)); + + await body.StopAsync(); + } + } + + [Fact] + public async Task StopAsyncPreventsFurtherDataConsumption() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "2" }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + // Add some input and consume it to ensure PumpAsync is running + input.Add("a"); + Assert.Equal(1, await stream.ReadAsync(new byte[1], 0, 1)); + + await body.StopAsync(); + + // Add some more data. Checking for cancelation and exiting the loop + // should take priority over reading this data. + input.Add("b"); + + // There shouldn't be any additional data available + Assert.Equal(0, await stream.ReadAsync(new byte[1], 0, 1)); + } + } + + [Fact] + public async Task ReadAsyncThrowsOnTimeout() + { + using (var input = new TestInput()) + { + var mockTimeoutControl = new Mock(); + + input.Http1ConnectionContext.TimeoutControl = mockTimeoutControl.Object; + + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + + // Add some input and read it to start PumpAsync + input.Add("a"); + Assert.Equal(1, await body.ReadAsync(new ArraySegment(new byte[1]))); + + // Time out on the next read + input.Http1Connection.SendTimeoutResponse(); + + var exception = await Assert.ThrowsAsync(async () => await body.ReadAsync(new Memory(new byte[1]))); + Assert.Equal(StatusCodes.Status408RequestTimeout, exception.StatusCode); + + await body.StopAsync(); + } + } + + [Fact] + public async Task ConsumeAsyncCompletesAndDoesNotThrowOnTimeout() + { + using (var input = new TestInput()) + { + var mockTimeoutControl = new Mock(); + input.Http1ConnectionContext.TimeoutControl = mockTimeoutControl.Object; + + var mockLogger = new Mock(); + input.Http1Connection.ServiceContext.Log = mockLogger.Object; + + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + + // Add some input and read it to start PumpAsync + input.Add("a"); + Assert.Equal(1, await body.ReadAsync(new ArraySegment(new byte[1]))); + + // Time out on the next read + input.Http1Connection.SendTimeoutResponse(); + + await body.ConsumeAsync(); + + mockLogger.Verify(logger => logger.ConnectionBadRequest( + It.IsAny(), + It.Is(ex => ex.Reason == RequestRejectionReason.RequestBodyTimeout))); + + await body.StopAsync(); + } + } + + [Fact] + public async Task CopyToAsyncThrowsOnTimeout() + { + using (var input = new TestInput()) + { + var mockTimeoutControl = new Mock(); + + input.Http1ConnectionContext.TimeoutControl = mockTimeoutControl.Object; + + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + + // Add some input and read it to start PumpAsync + input.Add("a"); + Assert.Equal(1, await body.ReadAsync(new ArraySegment(new byte[1]))); + + // Time out on the next read + input.Http1Connection.SendTimeoutResponse(); + + using (var ms = new MemoryStream()) + { + var exception = await Assert.ThrowsAsync(() => body.CopyToAsync(ms)); + Assert.Equal(StatusCodes.Status408RequestTimeout, exception.StatusCode); + } + + await body.StopAsync(); + } + } + + [Fact] + public async Task LogsWhenStartsReadingRequestBody() + { + using (var input = new TestInput()) + { + var mockLogger = new Mock(); + input.Http1Connection.ServiceContext.Log = mockLogger.Object; + input.Http1Connection.ConnectionIdFeature = "ConnectionId"; + input.Http1Connection.TraceIdentifier = "RequestId"; + + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "2" }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + // Add some input and consume it to ensure PumpAsync is running + input.Add("a"); + Assert.Equal(1, await stream.ReadAsync(new byte[1], 0, 1)); + + mockLogger.Verify(logger => logger.RequestBodyStart("ConnectionId", "RequestId")); + + input.Fin(); + + await body.StopAsync(); + } + } + + [Fact] + public async Task LogsWhenStopsReadingRequestBody() + { + using (var input = new TestInput()) + { + var logEvent = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var mockLogger = new Mock(); + mockLogger + .Setup(logger => logger.RequestBodyDone("ConnectionId", "RequestId")) + .Callback(() => logEvent.SetResult(null)); + input.Http1Connection.ServiceContext.Log = mockLogger.Object; + input.Http1Connection.ConnectionIdFeature = "ConnectionId"; + input.Http1Connection.TraceIdentifier = "RequestId"; + + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "2" }, input.Http1Connection); + var stream = new HttpRequestStream(Mock.Of()); + stream.StartAcceptingReads(body); + + // Add some input and consume it to ensure PumpAsync is running + input.Add("a"); + Assert.Equal(1, await stream.ReadAsync(new byte[1], 0, 1)); + + input.Fin(); + + await logEvent.Task.DefaultTimeout(); + + await body.StopAsync(); + } + } + + [Fact] + public async Task PausesAndResumesRequestBodyTimeoutOnBackpressure() + { + using (var input = new TestInput()) + { + var mockTimeoutControl = new Mock(); + input.Http1ConnectionContext.TimeoutControl = mockTimeoutControl.Object; + + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "12" }, input.Http1Connection); + + // Add some input and read it to start PumpAsync + input.Add("hello,"); + Assert.Equal(6, await body.ReadAsync(new ArraySegment(new byte[6]))); + + input.Add(" world"); + Assert.Equal(6, await body.ReadAsync(new ArraySegment(new byte[6]))); + + // Due to the limits set on HttpProtocol.RequestBodyPipe, backpressure should be triggered on every write to that pipe. + mockTimeoutControl.Verify(timeoutControl => timeoutControl.PauseTimingReads(), Times.Exactly(2)); + mockTimeoutControl.Verify(timeoutControl => timeoutControl.ResumeTimingReads(), Times.Exactly(2)); + } + } + + [Fact] + public async Task OnlyEnforcesRequestBodyTimeoutAfterSending100Continue() + { + using (var input = new TestInput()) + { + var produceContinueCalled = false; + var startTimingReadsCalledAfterProduceContinue = false; + + var mockHttpResponseControl = new Mock(); + mockHttpResponseControl + .Setup(httpResponseControl => httpResponseControl.ProduceContinue()) + .Callback(() => produceContinueCalled = true); + input.Http1Connection.HttpResponseControl = mockHttpResponseControl.Object; + + var mockTimeoutControl = new Mock(); + mockTimeoutControl + .Setup(timeoutControl => timeoutControl.StartTimingReads()) + .Callback(() => startTimingReadsCalledAfterProduceContinue = produceContinueCalled); + + input.Http1ConnectionContext.TimeoutControl = mockTimeoutControl.Object; + + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + + // Add some input and read it to start PumpAsync + var readTask = body.ReadAsync(new ArraySegment(new byte[1])); + + Assert.True(startTimingReadsCalledAfterProduceContinue); + + input.Add("a"); + await readTask; + + await body.StopAsync(); + } + } + + [Fact] + public async Task DoesNotEnforceRequestBodyTimeoutOnUpgradeRequests() + { + using (var input = new TestInput()) + { + var mockTimeoutControl = new Mock(); + input.Http1ConnectionContext.TimeoutControl = mockTimeoutControl.Object; + + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderConnection = "upgrade" }, input.Http1Connection); + + // Add some input and read it to start PumpAsync + input.Add("a"); + Assert.Equal(1, await body.ReadAsync(new ArraySegment(new byte[1]))); + + input.Fin(); + + Assert.Equal(0, await body.ReadAsync(new ArraySegment(new byte[1]))); + + mockTimeoutControl.Verify(timeoutControl => timeoutControl.StartTimingReads(), Times.Never); + mockTimeoutControl.Verify(timeoutControl => timeoutControl.StopTimingReads(), Times.Never); + + // Due to the limits set on HttpProtocol.RequestBodyPipe, backpressure should be triggered on every + // write to that pipe. Verify that read timing pause and resume are not called on upgrade + // requests. + mockTimeoutControl.Verify(timeoutControl => timeoutControl.PauseTimingReads(), Times.Never); + mockTimeoutControl.Verify(timeoutControl => timeoutControl.ResumeTimingReads(), Times.Never); + + await body.StopAsync(); + } + } + + private void AssertASCII(string expected, ArraySegment actual) + { + var encoding = Encoding.ASCII; + var bytes = encoding.GetBytes(expected); + Assert.Equal(bytes.Length, actual.Count); + for (var index = 0; index < bytes.Length; index++) + { + Assert.Equal(bytes[index], actual.Array[actual.Offset + index]); + } + } + + private class ThrowOnWriteSynchronousStream : Stream + { + public override void Flush() + { + throw new NotImplementedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + throw new XunitException(); + } + + public override bool CanRead { get; } + public override bool CanSeek { get; } + public override bool CanWrite => true; + public override long Length { get; } + public override long Position { get; set; } + } + + private class ThrowOnWriteAsynchronousStream : Stream + { + public override void Flush() + { + throw new NotImplementedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await Task.Delay(1); + throw new XunitException(); + } + + public override bool CanRead { get; } + public override bool CanSeek { get; } + public override bool CanWrite => true; + public override long Length { get; } + public override long Position { get; set; } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests.csproj b/src/Servers/Kestrel/Core/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests.csproj new file mode 100644 index 0000000000..e11845c901 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests.csproj @@ -0,0 +1,23 @@ + + + + netcoreapp2.1;net461 + true + + + + + + + + + + + + + + + + + + diff --git a/src/Servers/Kestrel/Core/test/MinDataRateTests.cs b/src/Servers/Kestrel/Core/test/MinDataRateTests.cs new file mode 100644 index 0000000000..a87bfa7709 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/MinDataRateTests.cs @@ -0,0 +1,63 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class MinDataRateTests + { + [Theory] + [InlineData(double.Epsilon)] + [InlineData(double.MaxValue)] + public void BytesPerSecondValid(double value) + { + Assert.Equal(value, new MinDataRate(bytesPerSecond: value, gracePeriod: TimeSpan.MaxValue).BytesPerSecond); + } + + [Theory] + [InlineData(double.MinValue)] + [InlineData(-double.Epsilon)] + [InlineData(0)] + public void BytesPerSecondInvalid(double value) + { + var exception = Assert.Throws(() => new MinDataRate(bytesPerSecond: value, gracePeriod: TimeSpan.MaxValue)); + + Assert.Equal("bytesPerSecond", exception.ParamName); + Assert.StartsWith(CoreStrings.PositiveNumberOrNullMinDataRateRequired, exception.Message); + } + + [Theory] + [MemberData(nameof(GracePeriodValidData))] + public void GracePeriodValid(TimeSpan value) + { + Assert.Equal(value, new MinDataRate(bytesPerSecond: 1, gracePeriod: value).GracePeriod); + } + + [Theory] + [MemberData(nameof(GracePeriodInvalidData))] + public void GracePeriodInvalid(TimeSpan value) + { + var exception = Assert.Throws(() => new MinDataRate(bytesPerSecond: 1, gracePeriod: value)); + + Assert.Equal("gracePeriod", exception.ParamName); + Assert.StartsWith(CoreStrings.FormatMinimumGracePeriodRequired(Heartbeat.Interval.TotalSeconds), exception.Message); + } + + public static TheoryData GracePeriodValidData => new TheoryData + { + Heartbeat.Interval + TimeSpan.FromTicks(1), + TimeSpan.MaxValue + }; + + public static TheoryData GracePeriodInvalidData => new TheoryData + { + TimeSpan.MinValue, + TimeSpan.FromTicks(-1), + TimeSpan.Zero, + Heartbeat.Interval + }; + } +} diff --git a/src/Servers/Kestrel/Core/test/OutputProducerTests.cs b/src/Servers/Kestrel/Core/test/OutputProducerTests.cs new file mode 100644 index 0000000000..7f3d566ff5 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/OutputProducerTests.cs @@ -0,0 +1,102 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Threading; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class OutputProducerTests : IDisposable + { + private readonly MemoryPool _memoryPool; + + public OutputProducerTests() + { + _memoryPool = KestrelMemoryPool.Create(); + } + + public void Dispose() + { + _memoryPool.Dispose(); + } + + [Fact] + public void WritesNoopAfterConnectionCloses() + { + var pipeOptions = new PipeOptions + ( + pool: _memoryPool, + readerScheduler: Mock.Of(), + writerScheduler: PipeScheduler.Inline, + useSynchronizationContext: false + ); + + using (var socketOutput = CreateOutputProducer(pipeOptions)) + { + // Close + socketOutput.Dispose(); + + var called = false; + + socketOutput.Write((buffer, state) => + { + called = true; + return 0; + }, + 0); + + Assert.False(called); + } + } + + [Fact] + public void AbortsTransportEvenAfterDispose() + { + var mockConnectionContext = new Mock(); + + var outputProducer = CreateOutputProducer(connectionContext: mockConnectionContext.Object); + + outputProducer.Dispose(); + + mockConnectionContext.Verify(f => f.Abort(It.IsAny()), Times.Never()); + + outputProducer.Abort(null); + + mockConnectionContext.Verify(f => f.Abort(null), Times.Once()); + + outputProducer.Abort(null); + + mockConnectionContext.Verify(f => f.Abort(null), Times.Once()); + } + + private Http1OutputProducer CreateOutputProducer( + PipeOptions pipeOptions = null, + ConnectionContext connectionContext = null) + { + pipeOptions = pipeOptions ?? new PipeOptions(); + connectionContext = connectionContext ?? Mock.Of(); + + var pipe = new Pipe(pipeOptions); + var serviceContext = new TestServiceContext(); + var socketOutput = new Http1OutputProducer( + pipe.Writer, + "0", + connectionContext, + serviceContext.Log, + Mock.Of(), + Mock.Of()); + + return socketOutput; + } + } +} diff --git a/src/Servers/Kestrel/Core/test/PathNormalizerTests.cs b/src/Servers/Kestrel/Core/test/PathNormalizerTests.cs new file mode 100644 index 0000000000..60a1e9b1c9 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/PathNormalizerTests.cs @@ -0,0 +1,65 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Text; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class PathNormalizerTests + { + [Theory] + [InlineData("/a", "/a")] + [InlineData("/a/", "/a/")] + [InlineData("/a/b", "/a/b")] + [InlineData("/a/b/", "/a/b/")] + [InlineData("/./a", "/a")] + [InlineData("/././a", "/a")] + [InlineData("/../a", "/a")] + [InlineData("/../../a", "/a")] + [InlineData("/a/./b", "/a/b")] + [InlineData("/a/../b", "/b")] + [InlineData("/a/./", "/a/")] + [InlineData("/a/.", "/a/")] + [InlineData("/a/../", "/")] + [InlineData("/a/..", "/")] + [InlineData("/a/../b/../", "/")] + [InlineData("/a/../b/..", "/")] + [InlineData("/a/../../b", "/b")] + [InlineData("/a/../../b/", "/b/")] + [InlineData("/a/.././../b", "/b")] + [InlineData("/a/.././../b/", "/b/")] + [InlineData("/a/b/c/./../../d", "/a/d")] + [InlineData("/./a/b/c/./../../d", "/a/d")] + [InlineData("/../a/b/c/./../../d", "/a/d")] + [InlineData("/./../a/b/c/./../../d", "/a/d")] + [InlineData("/.././a/b/c/./../../d", "/a/d")] + [InlineData("/.a", "/.a")] + [InlineData("/..a", "/..a")] + [InlineData("/...", "/...")] + [InlineData("/a/.../b", "/a/.../b")] + [InlineData("/a/../.../../b", "/b")] + [InlineData("/a/.b", "/a/.b")] + [InlineData("/a/..b", "/a/..b")] + [InlineData("/a/b.", "/a/b.")] + [InlineData("/a/b..", "/a/b..")] + [InlineData("/longlong/../short", "/short")] + [InlineData("/short/../longlong", "/longlong")] + [InlineData("/longlong/../short/..", "/")] + [InlineData("/short/../longlong/..", "/")] + [InlineData("/longlong/../short/../", "/")] + [InlineData("/short/../longlong/../", "/")] + [InlineData("/", "/")] + [InlineData("/no/segments", "/no/segments")] + [InlineData("/no/segments/", "/no/segments/")] + public void RemovesDotSegments(string input, string expected) + { + var data = Encoding.ASCII.GetBytes(input); + var length = PathNormalizer.RemoveDotSegments(new Span(data)); + Assert.True(length >= 1); + Assert.Equal(expected, Encoding.ASCII.GetString(data, 0, length)); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/PipeOptionsTests.cs b/src/Servers/Kestrel/Core/test/PipeOptionsTests.cs new file mode 100644 index 0000000000..3643fc44ba --- /dev/null +++ b/src/Servers/Kestrel/Core/test/PipeOptionsTests.cs @@ -0,0 +1,93 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.IO.Pipelines; +using System.Threading; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class PipeOptionsTests + { + [Theory] + [InlineData(10, 10, 10)] + [InlineData(0, 1, 1)] + [InlineData(null, 0, 0)] + public void OutputPipeOptionsConfiguredCorrectly(long? maxResponseBufferSize, long expectedMaximumSizeLow, long expectedMaximumSizeHigh) + { + var serviceContext = new TestServiceContext(); + serviceContext.ServerOptions.Limits.MaxResponseBufferSize = maxResponseBufferSize; + serviceContext.Scheduler = PipeScheduler.ThreadPool; + + var mockScheduler = Mock.Of(); + var outputPipeOptions = ConnectionDispatcher.GetOutputPipeOptions(serviceContext, KestrelMemoryPool.Create(), readerScheduler: mockScheduler); + + Assert.Equal(expectedMaximumSizeLow, outputPipeOptions.ResumeWriterThreshold); + Assert.Equal(expectedMaximumSizeHigh, outputPipeOptions.PauseWriterThreshold); + Assert.Same(mockScheduler, outputPipeOptions.ReaderScheduler); + Assert.Same(serviceContext.Scheduler, outputPipeOptions.WriterScheduler); + } + + [Theory] + [InlineData(10, 10, 10)] + [InlineData(null, 0, 0)] + public void InputPipeOptionsConfiguredCorrectly(long? maxRequestBufferSize, long expectedMaximumSizeLow, long expectedMaximumSizeHigh) + { + var serviceContext = new TestServiceContext(); + serviceContext.ServerOptions.Limits.MaxRequestBufferSize = maxRequestBufferSize; + serviceContext.Scheduler = PipeScheduler.ThreadPool; + + var mockScheduler = Mock.Of(); + var inputPipeOptions = ConnectionDispatcher.GetInputPipeOptions(serviceContext, KestrelMemoryPool.Create(), writerScheduler: mockScheduler); + + Assert.Equal(expectedMaximumSizeLow, inputPipeOptions.ResumeWriterThreshold); + Assert.Equal(expectedMaximumSizeHigh, inputPipeOptions.PauseWriterThreshold); + Assert.Same(serviceContext.Scheduler, inputPipeOptions.ReaderScheduler); + Assert.Same(mockScheduler, inputPipeOptions.WriterScheduler); + } + + [Theory] + [InlineData(10, 10, 10)] + [InlineData(null, 0, 0)] + public void AdaptedInputPipeOptionsConfiguredCorrectly(long? maxRequestBufferSize, long expectedMaximumSizeLow, long expectedMaximumSizeHigh) + { + var serviceContext = new TestServiceContext(); + serviceContext.ServerOptions.Limits.MaxRequestBufferSize = maxRequestBufferSize; + + var connectionLifetime = new HttpConnection(new HttpConnectionContext + { + ServiceContext = serviceContext + }); + + Assert.Equal(expectedMaximumSizeLow, connectionLifetime.AdaptedInputPipeOptions.ResumeWriterThreshold); + Assert.Equal(expectedMaximumSizeHigh, connectionLifetime.AdaptedInputPipeOptions.PauseWriterThreshold); + Assert.Same(serviceContext.Scheduler, connectionLifetime.AdaptedInputPipeOptions.ReaderScheduler); + Assert.Same(PipeScheduler.Inline, connectionLifetime.AdaptedInputPipeOptions.WriterScheduler); + } + + [Theory] + [InlineData(10, 10, 10)] + [InlineData(null, 0, 0)] + public void AdaptedOutputPipeOptionsConfiguredCorrectly(long? maxRequestBufferSize, long expectedMaximumSizeLow, long expectedMaximumSizeHigh) + { + var serviceContext = new TestServiceContext(); + serviceContext.ServerOptions.Limits.MaxResponseBufferSize = maxRequestBufferSize; + + var connectionLifetime = new HttpConnection(new HttpConnectionContext + { + ServiceContext = serviceContext + }); + + Assert.Equal(expectedMaximumSizeLow, connectionLifetime.AdaptedOutputPipeOptions.ResumeWriterThreshold); + Assert.Equal(expectedMaximumSizeHigh, connectionLifetime.AdaptedOutputPipeOptions.PauseWriterThreshold); + Assert.Same(PipeScheduler.Inline, connectionLifetime.AdaptedOutputPipeOptions.ReaderScheduler); + Assert.Same(PipeScheduler.Inline, connectionLifetime.AdaptedOutputPipeOptions.WriterScheduler); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/PipelineExtensionTests.cs b/src/Servers/Kestrel/Core/test/PipelineExtensionTests.cs new file mode 100644 index 0000000000..e3a89832da --- /dev/null +++ b/src/Servers/Kestrel/Core/test/PipelineExtensionTests.cs @@ -0,0 +1,180 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Text; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class PipelineExtensionTests : IDisposable + { + // ulong.MaxValue.ToString().Length + private const int _ulongMaxValueLength = 20; + + private readonly Pipe _pipe; + private readonly MemoryPool _memoryPool = KestrelMemoryPool.Create(); + + public PipelineExtensionTests() + { + _pipe = new Pipe(new PipeOptions(_memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false)); + } + + public void Dispose() + { + _memoryPool.Dispose(); + } + + [Theory] + [InlineData(ulong.MinValue)] + [InlineData(ulong.MaxValue)] + [InlineData(4_8_15_16_23_42)] + public void WritesNumericToAscii(ulong number) + { + var writerBuffer = _pipe.Writer; + var writer = new BufferWriter(writerBuffer); + writer.WriteNumeric(number); + writer.Commit(); + writerBuffer.FlushAsync().GetAwaiter().GetResult(); + + var reader = _pipe.Reader.ReadAsync().GetAwaiter().GetResult(); + var numAsStr = number.ToString(); + var expected = Encoding.ASCII.GetBytes(numAsStr); + AssertExtensions.Equal(expected, reader.Buffer.Slice(0, numAsStr.Length).ToArray()); + } + + [Theory] + [InlineData(1)] + [InlineData(_ulongMaxValueLength / 2)] + [InlineData(_ulongMaxValueLength - 1)] + public void WritesNumericAcrossSpanBoundaries(int gapSize) + { + var writerBuffer = _pipe.Writer; + var writer = new BufferWriter(writerBuffer); + // almost fill up the first block + var spacer = new byte[writer.Span.Length - gapSize]; + writer.Write(spacer); + + var bufferLength = writer.Span.Length; + writer.WriteNumeric(ulong.MaxValue); + Assert.NotEqual(bufferLength, writer.Span.Length); + writer.Commit(); + writerBuffer.FlushAsync().GetAwaiter().GetResult(); + + var reader = _pipe.Reader.ReadAsync().GetAwaiter().GetResult(); + var numAsString = ulong.MaxValue.ToString(); + var written = reader.Buffer.Slice(spacer.Length, numAsString.Length); + Assert.False(written.IsSingleSegment, "The buffer should cross spans"); + AssertExtensions.Equal(Encoding.ASCII.GetBytes(numAsString), written.ToArray()); + } + + [Theory] + [InlineData("\0abcxyz", new byte[] { 0, 97, 98, 99, 120, 121, 122 })] + [InlineData("!#$%i", new byte[] { 33, 35, 36, 37, 105 })] + [InlineData("!#$%", new byte[] { 33, 35, 36, 37 })] + [InlineData("!#$", new byte[] { 33, 35, 36 })] + [InlineData("!#", new byte[] { 33, 35 })] + [InlineData("!", new byte[] { 33 })] + // null or empty + [InlineData("", new byte[0])] + [InlineData(null, new byte[0])] + public void EncodesAsAscii(string input, byte[] expected) + { + var pipeWriter = _pipe.Writer; + var writer = new BufferWriter(pipeWriter); + writer.WriteAsciiNoValidation(input); + writer.Commit(); + pipeWriter.FlushAsync().GetAwaiter().GetResult(); + pipeWriter.Complete(); + + var reader = _pipe.Reader.ReadAsync().GetAwaiter().GetResult(); + + if (expected.Length > 0) + { + AssertExtensions.Equal( + expected, + reader.Buffer.ToArray()); + } + else + { + Assert.Equal(0, reader.Buffer.Length); + } + } + + [Theory] + // non-ascii characters stored in 32 bits + [InlineData("𤭢𐐝")] + // non-ascii characters stored in 16 bits + [InlineData("ñ٢⛄⛵")] + public void WriteAsciiNoValidationWritesOnlyOneBytePerChar(string input) + { + // WriteAscii doesn't validate if characters are in the ASCII range + // but it shouldn't produce more than one byte per character + var writerBuffer = _pipe.Writer; + var writer = new BufferWriter(writerBuffer); + writer.WriteAsciiNoValidation(input); + writer.Commit(); + writerBuffer.FlushAsync().GetAwaiter().GetResult(); + var reader = _pipe.Reader.ReadAsync().GetAwaiter().GetResult(); + + Assert.Equal(input.Length, reader.Buffer.Length); + } + + [Fact] + public void WriteAsciiNoValidation() + { + const byte maxAscii = 0x7f; + var writerBuffer = _pipe.Writer; + var writer = new BufferWriter(writerBuffer); + for (var i = 0; i < maxAscii; i++) + { + writer.WriteAsciiNoValidation(new string((char)i, 1)); + } + writer.Commit(); + writerBuffer.FlushAsync().GetAwaiter().GetResult(); + + var reader = _pipe.Reader.ReadAsync().GetAwaiter().GetResult(); + var data = reader.Buffer.Slice(0, maxAscii).ToArray(); + for (var i = 0; i < maxAscii; i++) + { + Assert.Equal(i, data[i]); + } + } + + [Theory] + [InlineData(2, 1)] + [InlineData(3, 1)] + [InlineData(4, 2)] + [InlineData(5, 3)] + [InlineData(7, 4)] + [InlineData(8, 3)] + [InlineData(8, 4)] + [InlineData(8, 5)] + [InlineData(100, 48)] + public void WritesAsciiAcrossBlockBoundaries(int stringLength, int gapSize) + { + var testString = new string(' ', stringLength); + var writerBuffer = _pipe.Writer; + var writer = new BufferWriter(writerBuffer); + // almost fill up the first block + var spacer = new byte[writer.Span.Length - gapSize]; + writer.Write(spacer); + Assert.Equal(gapSize, writer.Span.Length); + + var bufferLength = writer.Span.Length; + writer.WriteAsciiNoValidation(testString); + Assert.NotEqual(bufferLength, writer.Span.Length); + writer.Commit(); + writerBuffer.FlushAsync().GetAwaiter().GetResult(); + + var reader = _pipe.Reader.ReadAsync().GetAwaiter().GetResult(); + var written = reader.Buffer.Slice(spacer.Length, stringLength); + Assert.False(written.IsSingleSegment, "The buffer should cross spans"); + AssertExtensions.Equal(Encoding.ASCII.GetBytes(testString), written.ToArray()); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/ReasonPhrasesTests.cs b/src/Servers/Kestrel/Core/test/ReasonPhrasesTests.cs new file mode 100644 index 0000000000..9a7957735b --- /dev/null +++ b/src/Servers/Kestrel/Core/test/ReasonPhrasesTests.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Xunit; +using Microsoft.AspNetCore.Http; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class ReasonPhraseTests + { + [Theory] + [InlineData(999, "Unknown", "999 Unknown")] + [InlineData(999, null, "999 ")] + [InlineData(StatusCodes.Status200OK, "OK", "200 OK")] + [InlineData(StatusCodes.Status200OK, null, "200 OK")] + public void Formatting(int statusCode, string reasonPhrase, string expectedResult) + { + var bytes = Internal.Http.ReasonPhrases.ToStatusBytes(statusCode, reasonPhrase); + Assert.NotNull(bytes); + Assert.Equal(expectedResult, Encoding.ASCII.GetString(bytes)); + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Core/test/ResourceCounterTests.cs b/src/Servers/Kestrel/Core/test/ResourceCounterTests.cs new file mode 100644 index 0000000000..118b195248 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/ResourceCounterTests.cs @@ -0,0 +1,56 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class ResourceCounterTests + { + [Theory] + [InlineData(-1)] + [InlineData(long.MinValue)] + public void QuotaInvalid(long max) + { + Assert.Throws(() => ResourceCounter.Quota(max)); + } + + [Fact] + public void QuotaAcceptsUpToButNotMoreThanMax() + { + var counter = ResourceCounter.Quota(1); + Assert.True(counter.TryLockOne()); + Assert.False(counter.TryLockOne()); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(10)] + [InlineData(100)] + public void QuotaValid(long max) + { + var counter = ResourceCounter.Quota(max); + Parallel.For(0, max, i => + { + Assert.True(counter.TryLockOne()); + }); + + Parallel.For(0, 10, i => + { + Assert.False(counter.TryLockOne()); + }); + } + + [Fact] + public void QuotaDoesNotWrapAround() + { + var counter = new ResourceCounter.FiniteCounter(long.MaxValue); + counter.Count = long.MaxValue; + Assert.False(counter.TryLockOne()); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/ServerAddressTests.cs b/src/Servers/Kestrel/Core/test/ServerAddressTests.cs new file mode 100644 index 0000000000..94dc2ee3ab --- /dev/null +++ b/src/Servers/Kestrel/Core/test/ServerAddressTests.cs @@ -0,0 +1,69 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class ServerAddressTests + { + [Theory] + [InlineData("")] + [InlineData("5000")] + [InlineData("//noscheme")] + public void FromUriThrowsForUrlsWithoutSchemeDelimiter(string url) + { + Assert.Throws(() => ServerAddress.FromUrl(url)); + } + + [Theory] + [InlineData("://")] + [InlineData("://:5000")] + [InlineData("http://")] + [InlineData("http://:5000")] + [InlineData("http:///")] + [InlineData("http:///:5000")] + [InlineData("http:////")] + [InlineData("http:////:5000")] + public void FromUriThrowsForUrlsWithoutHost(string url) + { + Assert.Throws(() => ServerAddress.FromUrl(url)); + } + + [Theory] + [InlineData("://emptyscheme", "", "emptyscheme", 0, "", "://emptyscheme:0")] + [InlineData("http://+", "http", "+", 80, "", "http://+:80")] + [InlineData("http://*", "http", "*", 80, "", "http://*:80")] + [InlineData("http://localhost", "http", "localhost", 80, "", "http://localhost:80")] + [InlineData("http://www.example.com", "http", "www.example.com", 80, "", "http://www.example.com:80")] + [InlineData("https://www.example.com", "https", "www.example.com", 443, "", "https://www.example.com:443")] + [InlineData("http://www.example.com/", "http", "www.example.com", 80, "", "http://www.example.com:80")] + [InlineData("http://www.example.com/foo?bar=baz", "http", "www.example.com", 80, "/foo?bar=baz", "http://www.example.com:80")] + [InlineData("http://www.example.com:5000", "http", "www.example.com", 5000, "", null)] + [InlineData("https://www.example.com:5000", "https", "www.example.com", 5000, "", null)] + [InlineData("http://www.example.com:5000/", "http", "www.example.com", 5000, "", "http://www.example.com:5000")] + [InlineData("http://www.example.com:NOTAPORT", "http", "www.example.com:NOTAPORT", 80, "", "http://www.example.com:notaport:80")] + [InlineData("https://www.example.com:NOTAPORT", "https", "www.example.com:NOTAPORT", 443, "", "https://www.example.com:notaport:443")] + [InlineData("http://www.example.com:NOTAPORT/", "http", "www.example.com:NOTAPORT", 80, "", "http://www.example.com:notaport:80")] + [InlineData("http://foo:/tmp/kestrel-test.sock:5000/doesn't/matter", "http", "foo:", 80, "/tmp/kestrel-test.sock:5000/doesn't/matter", "http://foo::80")] + [InlineData("http://unix:foo/tmp/kestrel-test.sock", "http", "unix:foo", 80, "/tmp/kestrel-test.sock", "http://unix:foo:80")] + [InlineData("http://unix:5000/tmp/kestrel-test.sock", "http", "unix", 5000, "/tmp/kestrel-test.sock", "http://unix:5000")] + [InlineData("http://unix:/tmp/kestrel-test.sock", "http", "unix:/tmp/kestrel-test.sock", 0, "", null)] + [InlineData("https://unix:/tmp/kestrel-test.sock", "https", "unix:/tmp/kestrel-test.sock", 0, "", null)] + [InlineData("http://unix:/tmp/kestrel-test.sock:", "http", "unix:/tmp/kestrel-test.sock", 0, "", "http://unix:/tmp/kestrel-test.sock")] + [InlineData("http://unix:/tmp/kestrel-test.sock:/", "http", "unix:/tmp/kestrel-test.sock", 0, "", "http://unix:/tmp/kestrel-test.sock")] + [InlineData("http://unix:/tmp/kestrel-test.sock:5000/doesn't/matter", "http", "unix:/tmp/kestrel-test.sock", 0, "5000/doesn't/matter", "http://unix:/tmp/kestrel-test.sock")] + public void UrlsAreParsedCorrectly(string url, string scheme, string host, int port, string pathBase, string toString) + { + var serverAddress = ServerAddress.FromUrl(url); + + Assert.Equal(scheme, serverAddress.Scheme); + Assert.Equal(host, serverAddress.Host); + Assert.Equal(port, serverAddress.Port); + Assert.Equal(pathBase, serverAddress.PathBase); + + Assert.Equal(toString ?? url, serverAddress.ToString()); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/StreamsTests.cs b/src/Servers/Kestrel/Core/test/StreamsTests.cs new file mode 100644 index 0000000000..e80b745640 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/StreamsTests.cs @@ -0,0 +1,84 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class StreamsTests + { + [Fact] + public async Task StreamsThrowAfterAbort() + { + var streams = new Streams(Mock.Of(), Mock.Of()); + var (request, response) = streams.Start(new MockMessageBody()); + + var ex = new Exception("My error"); + streams.Abort(ex); + + await response.WriteAsync(new byte[1], 0, 1); + Assert.Same(ex, + await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); + } + + [Fact] + public async Task StreamsThrowOnAbortAfterUpgrade() + { + var streams = new Streams(Mock.Of(), Mock.Of()); + var (request, response) = streams.Start(new MockMessageBody(upgradeable: true)); + + var upgrade = streams.Upgrade(); + var ex = new Exception("My error"); + streams.Abort(ex); + + var writeEx = await Assert.ThrowsAsync(() => response.WriteAsync(new byte[1], 0, 1)); + Assert.Equal(CoreStrings.ResponseStreamWasUpgraded, writeEx.Message); + + Assert.Same(ex, + await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); + + Assert.Same(ex, + await Assert.ThrowsAsync(() => upgrade.ReadAsync(new byte[1], 0, 1))); + + await upgrade.WriteAsync(new byte[1], 0, 1); + } + + [Fact] + public async Task StreamsThrowOnUpgradeAfterAbort() + { + var streams = new Streams(Mock.Of(), Mock.Of()); + + var (request, response) = streams.Start(new MockMessageBody(upgradeable: true)); + var ex = new Exception("My error"); + streams.Abort(ex); + + var upgrade = streams.Upgrade(); + + var writeEx = await Assert.ThrowsAsync(() => response.WriteAsync(new byte[1], 0, 1)); + Assert.Equal(CoreStrings.ResponseStreamWasUpgraded, writeEx.Message); + + Assert.Same(ex, + await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); + + Assert.Same(ex, + await Assert.ThrowsAsync(() => upgrade.ReadAsync(new byte[1], 0, 1))); + + await upgrade.WriteAsync(new byte[1], 0, 1); + } + + private class MockMessageBody : Http1MessageBody + { + public MockMessageBody(bool upgradeable = false) + : base(null) + { + RequestUpgrade = upgradeable; + } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/StringUtilitiesTests.cs b/src/Servers/Kestrel/Core/test/StringUtilitiesTests.cs new file mode 100644 index 0000000000..5ef5f30cc1 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/StringUtilitiesTests.cs @@ -0,0 +1,31 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class StringUtilitiesTests + { + [Theory] + [InlineData(uint.MinValue)] + [InlineData(0xF)] + [InlineData(0xA)] + [InlineData(0xFF)] + [InlineData(0xFFC59)] + [InlineData(uint.MaxValue)] + public void ConvertsToHex(uint value) + { + var str = CorrelationIdGenerator.GetNextId(); + Assert.Equal($"{str}:{value:X8}", StringUtilities.ConcatAsHexSuffix(str, ':', value)); + } + + [Fact] + public void HandlesNull() + { + uint value = 0x23BC0234; + Assert.Equal(":23BC0234", StringUtilities.ConcatAsHexSuffix(null, ':', value)); + } + } +} diff --git a/src/Servers/Kestrel/Core/test/TestHelpers/AssertExtensions.cs b/src/Servers/Kestrel/Core/test/TestHelpers/AssertExtensions.cs new file mode 100644 index 0000000000..cb3fc36a3c --- /dev/null +++ b/src/Servers/Kestrel/Core/test/TestHelpers/AssertExtensions.cs @@ -0,0 +1,27 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Xunit.Sdk; + +namespace Xunit +{ + public static class AssertExtensions + { + public static void Equal(byte[] expected, Span actual) + { + if (expected.Length != actual.Length) + { + throw new XunitException($"Expected length to be {expected.Length} but was {actual.Length}"); + } + + for (var i = 0; i < expected.Length; i++) + { + if (expected[i] != actual[i]) + { + throw new XunitException($@"Expected byte at index {i} to be '{expected[i]}' but was '{actual[i]}'"); + } + } + } + } +} diff --git a/src/Servers/Kestrel/Core/test/TestHelpers/MockHttpResponseControl.cs b/src/Servers/Kestrel/Core/test/TestHelpers/MockHttpResponseControl.cs new file mode 100644 index 0000000000..738a070635 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/TestHelpers/MockHttpResponseControl.cs @@ -0,0 +1,27 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests.TestHelpers +{ + public class MockHttpResponseControl : IHttpResponseControl + { + public Task FlushAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public void ProduceContinue() + { + } + + public Task WriteAsync(ReadOnlyMemory data, CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + } +} diff --git a/src/Servers/Kestrel/Core/test/TestInput.cs b/src/Servers/Kestrel/Core/test/TestInput.cs new file mode 100644 index 0000000000..ef0bf7613c --- /dev/null +++ b/src/Servers/Kestrel/Core/test/TestInput.cs @@ -0,0 +1,82 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Moq; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + class TestInput : IDisposable + { + private MemoryPool _memoryPool; + + public TestInput() + { + _memoryPool = KestrelMemoryPool.Create(); + var options = new PipeOptions(pool: _memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + Transport = pair.Transport; + Application = pair.Application; + + var connectionFeatures = new FeatureCollection(); + connectionFeatures.Set(Mock.Of()); + connectionFeatures.Set(Mock.Of()); + + Http1ConnectionContext = new Http1ConnectionContext + { + ServiceContext = new TestServiceContext(), + ConnectionContext = Mock.Of(), + ConnectionFeatures = connectionFeatures, + Application = Application, + Transport = Transport, + MemoryPool = _memoryPool, + TimeoutControl = Mock.Of() + }; + + Http1Connection = new Http1Connection(Http1ConnectionContext); + Http1Connection.HttpResponseControl = Mock.Of(); + } + + public IDuplexPipe Transport { get; } + + public IDuplexPipe Application { get; } + + public Http1ConnectionContext Http1ConnectionContext { get; } + + public Http1Connection Http1Connection { get; set; } + + public void Add(string text) + { + var data = Encoding.ASCII.GetBytes(text); + async Task Write() => await Application.Output.WriteAsync(data); + Write().Wait(); + } + + public void Fin() + { + Application.Output.Complete(); + } + + public void Cancel() + { + Transport.Input.CancelPendingRead(); + } + + public void Dispose() + { + _memoryPool.Dispose(); + } + } +} + diff --git a/src/Servers/Kestrel/Core/test/ThrowingWriteOnlyStreamTests.cs b/src/Servers/Kestrel/Core/test/ThrowingWriteOnlyStreamTests.cs new file mode 100644 index 0000000000..7e6490e4f3 --- /dev/null +++ b/src/Servers/Kestrel/Core/test/ThrowingWriteOnlyStreamTests.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class ThrowingWriteOnlyStreamTests + { + [Fact] + public async Task ThrowsOnWrite() + { + var ex = new Exception("my error"); + var stream = new ThrowingWriteOnlyStream(ex); + + Assert.True(stream.CanWrite); + Assert.False(stream.CanRead); + Assert.False(stream.CanSeek); + Assert.False(stream.CanTimeout); + Assert.Same(ex, Assert.Throws(() => stream.Write(new byte[1], 0, 1))); + Assert.Same(ex, await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1], 0, 1))); + Assert.Same(ex, Assert.Throws(() => stream.Flush())); + Assert.Same(ex, await Assert.ThrowsAsync(() => stream.FlushAsync())); + } + } +} diff --git a/src/Servers/Kestrel/Directory.Build.props b/src/Servers/Kestrel/Directory.Build.props new file mode 100644 index 0000000000..ad68f6ce2a --- /dev/null +++ b/src/Servers/Kestrel/Directory.Build.props @@ -0,0 +1,25 @@ + + + + + $(DefineConstants);INNER_LOOP + + + + + false + $(MSBuildThisFileDirectory)shared\ + + + + + true + + + + + + diff --git a/src/Servers/Kestrel/Https/src/Microsoft.AspNetCore.Server.Kestrel.Https.csproj b/src/Servers/Kestrel/Https/src/Microsoft.AspNetCore.Server.Kestrel.Https.csproj new file mode 100644 index 0000000000..9e6fde40bd --- /dev/null +++ b/src/Servers/Kestrel/Https/src/Microsoft.AspNetCore.Server.Kestrel.Https.csproj @@ -0,0 +1,16 @@ + + + + HTTPS support for the ASP.NET Core Kestrel cross-platform web server. + netstandard2.0;netcoreapp2.1 + true + aspnetcore;kestrel + CS1591;$(NoWarn) + + + + + + + + diff --git a/src/Servers/Kestrel/Https/src/Properties/AssemblyInfo.cs b/src/Servers/Kestrel/Https/src/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..c99dd8d3e3 --- /dev/null +++ b/src/Servers/Kestrel/Https/src/Properties/AssemblyInfo.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; + +[assembly: TypeForwardedTo(typeof(ClientCertificateMode))] +[assembly: TypeForwardedTo(typeof(HttpsConnectionAdapter))] +[assembly: TypeForwardedTo(typeof(HttpsConnectionAdapterOptions))] +[assembly: TypeForwardedTo(typeof(ListenOptionsHttpsExtensions))] \ No newline at end of file diff --git a/src/Servers/Kestrel/Https/src/baseline.netcore.json b/src/Servers/Kestrel/Https/src/baseline.netcore.json new file mode 100644 index 0000000000..4919ecc705 --- /dev/null +++ b/src/Servers/Kestrel/Https/src/baseline.netcore.json @@ -0,0 +1,249 @@ +{ + "AssemblyIdentity": "Microsoft.AspNetCore.Server.Kestrel.Https, Version=2.0.2.0, Culture=neutral, PublicKeyToken=adb9793829ddae60", + "Types": [ + { + "Name": "Microsoft.AspNetCore.Hosting.ListenOptionsHttpsExtensions", + "Visibility": "Public", + "Kind": "Class", + "Abstract": true, + "Static": true, + "Sealed": true, + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "UseHttps", + "Parameters": [ + { + "Name": "listenOptions", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions" + }, + { + "Name": "fileName", + "Type": "System.String" + } + ], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions", + "Static": true, + "Extension": true, + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "UseHttps", + "Parameters": [ + { + "Name": "listenOptions", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions" + }, + { + "Name": "fileName", + "Type": "System.String" + }, + { + "Name": "password", + "Type": "System.String" + } + ], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions", + "Static": true, + "Extension": true, + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "UseHttps", + "Parameters": [ + { + "Name": "listenOptions", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions" + }, + { + "Name": "serverCertificate", + "Type": "System.Security.Cryptography.X509Certificates.X509Certificate2" + } + ], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions", + "Static": true, + "Extension": true, + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "UseHttps", + "Parameters": [ + { + "Name": "listenOptions", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions" + }, + { + "Name": "httpsOptions", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Https.HttpsConnectionAdapterOptions" + } + ], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions", + "Static": true, + "Extension": true, + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Https.ClientCertificateMode", + "Visibility": "Public", + "Kind": "Enumeration", + "Sealed": true, + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Field", + "Name": "NoCertificate", + "Parameters": [], + "GenericParameter": [], + "Literal": "0" + }, + { + "Kind": "Field", + "Name": "AllowCertificate", + "Parameters": [], + "GenericParameter": [], + "Literal": "1" + }, + { + "Kind": "Field", + "Name": "RequireCertificate", + "Parameters": [], + "GenericParameter": [], + "Literal": "2" + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Https.HttpsConnectionAdapterOptions", + "Visibility": "Public", + "Kind": "Class", + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "get_ServerCertificate", + "Parameters": [], + "ReturnType": "System.Security.Cryptography.X509Certificates.X509Certificate2", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_ServerCertificate", + "Parameters": [ + { + "Name": "value", + "Type": "System.Security.Cryptography.X509Certificates.X509Certificate2" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_ClientCertificateMode", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Https.ClientCertificateMode", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_ClientCertificateMode", + "Parameters": [ + { + "Name": "value", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Https.ClientCertificateMode" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_ClientCertificateValidation", + "Parameters": [], + "ReturnType": "System.Func", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_ClientCertificateValidation", + "Parameters": [ + { + "Name": "value", + "Type": "System.Func" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_SslProtocols", + "Parameters": [], + "ReturnType": "System.Security.Authentication.SslProtocols", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_SslProtocols", + "Parameters": [ + { + "Name": "value", + "Type": "System.Security.Authentication.SslProtocols" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_CheckCertificateRevocation", + "Parameters": [], + "ReturnType": "System.Boolean", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_CheckCertificateRevocation", + "Parameters": [ + { + "Name": "value", + "Type": "System.Boolean" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Constructor", + "Name": ".ctor", + "Parameters": [], + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + } + ] +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Kestrel/src/Microsoft.AspNetCore.Server.Kestrel.csproj b/src/Servers/Kestrel/Kestrel/src/Microsoft.AspNetCore.Server.Kestrel.csproj new file mode 100644 index 0000000000..6b8de4426f --- /dev/null +++ b/src/Servers/Kestrel/Kestrel/src/Microsoft.AspNetCore.Server.Kestrel.csproj @@ -0,0 +1,18 @@ + + + + ASP.NET Core Kestrel cross-platform web server. + netstandard2.0 + true + aspnetcore;kestrel + CS1591;$(NoWarn) + + + + + + + + + + diff --git a/src/Servers/Kestrel/Kestrel/src/WebHostBuilderKestrelExtensions.cs b/src/Servers/Kestrel/Kestrel/src/WebHostBuilderKestrelExtensions.cs new file mode 100644 index 0000000000..6552da10f2 --- /dev/null +++ b/src/Servers/Kestrel/Kestrel/src/WebHostBuilderKestrelExtensions.cs @@ -0,0 +1,85 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Hosting +{ + public static class WebHostBuilderKestrelExtensions + { + /// + /// Specify Kestrel as the server to be used by the web host. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder to configure. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder. + /// + public static IWebHostBuilder UseKestrel(this IWebHostBuilder hostBuilder) + { + return hostBuilder.ConfigureServices(services => + { + // Don't override an already-configured transport + services.TryAddSingleton(); + + services.AddTransient, KestrelServerOptionsSetup>(); + services.AddSingleton(); + }); + } + + /// + /// Specify Kestrel as the server to be used by the web host. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder to configure. + /// + /// + /// A callback to configure Kestrel options. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder. + /// + public static IWebHostBuilder UseKestrel(this IWebHostBuilder hostBuilder, Action options) + { + return hostBuilder.UseKestrel().ConfigureServices(services => + { + services.Configure(options); + }); + } + + /// + /// Specify Kestrel as the server to be used by the web host. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder to configure. + /// + /// A callback to configure Kestrel options. + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder. + /// + public static IWebHostBuilder UseKestrel(this IWebHostBuilder hostBuilder, Action configureOptions) + { + if (configureOptions == null) + { + throw new ArgumentNullException(nameof(configureOptions)); + } + + return hostBuilder.UseKestrel().ConfigureServices((context, services) => + { + services.Configure(options => + { + configureOptions(context, options); + }); + }); + } + } +} diff --git a/src/Servers/Kestrel/Kestrel/src/baseline.netcore.json b/src/Servers/Kestrel/Kestrel/src/baseline.netcore.json new file mode 100644 index 0000000000..7f71b30042 --- /dev/null +++ b/src/Servers/Kestrel/Kestrel/src/baseline.netcore.json @@ -0,0 +1,51 @@ +{ + "AssemblyIdentity": "Microsoft.AspNetCore.Server.Kestrel, Version=2.0.2.0, Culture=neutral, PublicKeyToken=adb9793829ddae60", + "Types": [ + { + "Name": "Microsoft.AspNetCore.Hosting.WebHostBuilderKestrelExtensions", + "Visibility": "Public", + "Kind": "Class", + "Abstract": true, + "Static": true, + "Sealed": true, + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "UseKestrel", + "Parameters": [ + { + "Name": "hostBuilder", + "Type": "Microsoft.AspNetCore.Hosting.IWebHostBuilder" + } + ], + "ReturnType": "Microsoft.AspNetCore.Hosting.IWebHostBuilder", + "Static": true, + "Extension": true, + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "UseKestrel", + "Parameters": [ + { + "Name": "hostBuilder", + "Type": "Microsoft.AspNetCore.Hosting.IWebHostBuilder" + }, + { + "Name": "options", + "Type": "System.Action" + } + ], + "ReturnType": "Microsoft.AspNetCore.Hosting.IWebHostBuilder", + "Static": true, + "Extension": true, + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + } + ] +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Kestrel/test/ConfigurationReaderTests.cs b/src/Servers/Kestrel/Kestrel/test/ConfigurationReaderTests.cs new file mode 100644 index 0000000000..ecc7f5e587 --- /dev/null +++ b/src/Servers/Kestrel/Kestrel/test/ConfigurationReaderTests.cs @@ -0,0 +1,177 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.Extensions.Configuration; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Tests +{ + public class ConfigurationReaderTests + { + [Fact] + public void ReadCertificatesWhenNoCertificatsSection_ReturnsEmptyCollection() + { + var config = new ConfigurationBuilder().AddInMemoryCollection().Build(); + var reader = new ConfigurationReader(config); + var certificates = reader.Certificates; + Assert.NotNull(certificates); + Assert.False(certificates.Any()); + } + + [Fact] + public void ReadCertificatesWhenEmptyCertificatsSection_ReturnsEmptyCollection() + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Certificates", ""), + }).Build(); + var reader = new ConfigurationReader(config); + var certificates = reader.Certificates; + Assert.NotNull(certificates); + Assert.False(certificates.Any()); + } + + [Fact] + public void ReadCertificatsSection_ReturnsCollection() + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Certificates:FileCert:Path", "/path/cert.pfx"), + new KeyValuePair("Certificates:FileCert:Password", "certpassword"), + new KeyValuePair("Certificates:StoreCert:Subject", "certsubject"), + new KeyValuePair("Certificates:StoreCert:Store", "certstore"), + new KeyValuePair("Certificates:StoreCert:Location", "cetlocation"), + new KeyValuePair("Certificates:StoreCert:AllowInvalid", "true"), + }).Build(); + var reader = new ConfigurationReader(config); + var certificates = reader.Certificates; + Assert.NotNull(certificates); + Assert.Equal(2, certificates.Count); + + var fileCert = certificates["FileCert"]; + Assert.True(fileCert.IsFileCert); + Assert.False(fileCert.IsStoreCert); + Assert.Equal("/path/cert.pfx", fileCert.Path); + Assert.Equal("certpassword", fileCert.Password); + + var storeCert = certificates["StoreCert"]; + Assert.False(storeCert.IsFileCert); + Assert.True(storeCert.IsStoreCert); + Assert.Equal("certsubject", storeCert.Subject); + Assert.Equal("certstore", storeCert.Store); + Assert.Equal("cetlocation", storeCert.Location); + Assert.True(storeCert.AllowInvalid); + } + + [Fact] + public void ReadEndpointsWhenNoEndpointsSection_ReturnsEmptyCollection() + { + var config = new ConfigurationBuilder().AddInMemoryCollection().Build(); + var reader = new ConfigurationReader(config); + var endpoints = reader.Endpoints; + Assert.NotNull(endpoints); + Assert.False(endpoints.Any()); + } + + [Fact] + public void ReadEndpointsWhenEmptyEndpointsSection_ReturnsEmptyCollection() + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints", ""), + }).Build(); + var reader = new ConfigurationReader(config); + var endpoints = reader.Endpoints; + Assert.NotNull(endpoints); + Assert.False(endpoints.Any()); + } + + [Fact] + public void ReadEndpointWithMissingUrl_Throws() + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:End1", ""), + }).Build(); + var reader = new ConfigurationReader(config); + Assert.Throws(() => reader.Endpoints); + } + + [Fact] + public void ReadEndpointWithEmptyUrl_Throws() + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:End1:Url", ""), + }).Build(); + var reader = new ConfigurationReader(config); + Assert.Throws(() => reader.Endpoints); + } + + [Fact] + public void ReadEndpointsSection_ReturnsCollection() + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:End1:Url", "http://*:5001"), + new KeyValuePair("Endpoints:End2:Url", "https://*:5002"), + new KeyValuePair("Endpoints:End3:Url", "https://*:5003"), + new KeyValuePair("Endpoints:End3:Certificate:Path", "/path/cert.pfx"), + new KeyValuePair("Endpoints:End3:Certificate:Password", "certpassword"), + new KeyValuePair("Endpoints:End4:Url", "https://*:5004"), + new KeyValuePair("Endpoints:End4:Certificate:Subject", "certsubject"), + new KeyValuePair("Endpoints:End4:Certificate:Store", "certstore"), + new KeyValuePair("Endpoints:End4:Certificate:Location", "cetlocation"), + new KeyValuePair("Endpoints:End4:Certificate:AllowInvalid", "true"), + }).Build(); + var reader = new ConfigurationReader(config); + var endpoints = reader.Endpoints; + Assert.NotNull(endpoints); + Assert.Equal(4, endpoints.Count()); + + var end1 = endpoints.First(); + Assert.Equal("End1", end1.Name); + Assert.Equal("http://*:5001", end1.Url); + Assert.NotNull(end1.ConfigSection); + Assert.NotNull(end1.Certificate); + Assert.False(end1.Certificate.ConfigSection.Exists()); + + var end2 = endpoints.Skip(1).First(); + Assert.Equal("End2", end2.Name); + Assert.Equal("https://*:5002", end2.Url); + Assert.NotNull(end2.ConfigSection); + Assert.NotNull(end2.Certificate); + Assert.False(end2.Certificate.ConfigSection.Exists()); + + var end3 = endpoints.Skip(2).First(); + Assert.Equal("End3", end3.Name); + Assert.Equal("https://*:5003", end3.Url); + Assert.NotNull(end3.ConfigSection); + Assert.NotNull(end3.Certificate); + Assert.True(end3.Certificate.ConfigSection.Exists()); + var cert3 = end3.Certificate; + Assert.True(cert3.IsFileCert); + Assert.False(cert3.IsStoreCert); + Assert.Equal("/path/cert.pfx", cert3.Path); + Assert.Equal("certpassword", cert3.Password); + + var end4 = endpoints.Skip(3).First(); + Assert.Equal("End4", end4.Name); + Assert.Equal("https://*:5004", end4.Url); + Assert.NotNull(end4.ConfigSection); + Assert.NotNull(end4.Certificate); + Assert.True(end4.Certificate.ConfigSection.Exists()); + var cert4 = end4.Certificate; + Assert.False(cert4.IsFileCert); + Assert.True(cert4.IsStoreCert); + Assert.Equal("certsubject", cert4.Subject); + Assert.Equal("certstore", cert4.Store); + Assert.Equal("cetlocation", cert4.Location); + Assert.True(cert4.AllowInvalid); + } + } +} diff --git a/src/Servers/Kestrel/Kestrel/test/KestrelConfigurationBuilderTests.cs b/src/Servers/Kestrel/Kestrel/test/KestrelConfigurationBuilderTests.cs new file mode 100644 index 0000000000..33caeb1d34 --- /dev/null +++ b/src/Servers/Kestrel/Kestrel/test/KestrelConfigurationBuilderTests.cs @@ -0,0 +1,328 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Security.Cryptography.X509Certificates; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Hosting.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Tests +{ + public class KestrelConfigurationBuilderTests + { + private KestrelServerOptions CreateServerOptions() + { + var serverOptions = new KestrelServerOptions(); + serverOptions.ApplicationServices = new ServiceCollection() + .AddLogging() + .AddSingleton(new HostingEnvironment() { ApplicationName = "TestApplication" }) + .BuildServiceProvider(); + return serverOptions; + } + + [Fact] + public void ConfigureNamedEndpoint_OnlyRunForMatchingConfig() + { + var found = false; + var serverOptions = CreateServerOptions(); + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:Found:Url", "http://*:5001"), + }).Build(); + serverOptions.Configure(config) + .Endpoint("Found", endpointOptions => found = true) + .Endpoint("NotFound", endpointOptions => throw new NotImplementedException()) + .Load(); + + Assert.Single(serverOptions.ListenOptions); + Assert.Equal(5001, serverOptions.ListenOptions[0].IPEndPoint.Port); + + Assert.True(found); + } + + [Fact] + public void ConfigureEndpoint_OnlyRunWhenBuildIsCalled() + { + var run = false; + var serverOptions = CreateServerOptions(); + serverOptions.Configure() + .LocalhostEndpoint(5001, endpointOptions => run = true); + + Assert.Empty(serverOptions.ListenOptions); + + serverOptions.ConfigurationLoader.Load(); + + Assert.Single(serverOptions.ListenOptions); + Assert.Equal(5001, serverOptions.ListenOptions[0].IPEndPoint.Port); + + Assert.True(run); + } + + [Fact] + public void CallBuildTwice_OnlyRunsOnce() + { + var serverOptions = CreateServerOptions(); + var builder = serverOptions.Configure() + .LocalhostEndpoint(5001); + + Assert.Empty(serverOptions.ListenOptions); + Assert.Equal(builder, serverOptions.ConfigurationLoader); + + builder.Load(); + + Assert.Single(serverOptions.ListenOptions); + Assert.Equal(5001, serverOptions.ListenOptions[0].IPEndPoint.Port); + Assert.Null(serverOptions.ConfigurationLoader); + + builder.Load(); + + Assert.Single(serverOptions.ListenOptions); + Assert.Equal(5001, serverOptions.ListenOptions[0].IPEndPoint.Port); + Assert.Null(serverOptions.ConfigurationLoader); + } + + [Fact] + public void Configure_IsReplacable() + { + var run1 = false; + var serverOptions = CreateServerOptions(); + var config1 = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:End1:Url", "http://*:5001"), + }).Build(); + serverOptions.Configure(config1) + .LocalhostEndpoint(5001, endpointOptions => run1 = true); + + Assert.Empty(serverOptions.ListenOptions); + Assert.False(run1); + + var run2 = false; + var config2 = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:End2:Url", "http://*:5002"), + }).Build(); + serverOptions.Configure(config2) + .LocalhostEndpoint(5003, endpointOptions => run2 = true); + + serverOptions.ConfigurationLoader.Load(); + + Assert.Equal(2, serverOptions.ListenOptions.Count); + Assert.Equal(5002, serverOptions.ListenOptions[0].IPEndPoint.Port); + Assert.Equal(5003, serverOptions.ListenOptions[1].IPEndPoint.Port); + + Assert.False(run1); + Assert.True(run2); + } + + [Fact] + public void ConfigureDefaultsAppliesToNewConfigureEndpoints() + { + var serverOptions = CreateServerOptions(); + + serverOptions.ConfigureEndpointDefaults(opt => + { + opt.NoDelay = false; + }); + + serverOptions.ConfigureHttpsDefaults(opt => + { + opt.ServerCertificate = new X509Certificate2(TestResources.TestCertificatePath, "testPassword"); + opt.ClientCertificateMode = ClientCertificateMode.RequireCertificate; + }); + + var ran1 = false; + var ran2 = false; + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:End1:Url", "https://*:5001"), + }).Build(); + serverOptions.Configure(config) + .Endpoint("End1", opt => + { + ran1 = true; + Assert.True(opt.IsHttps); + Assert.NotNull(opt.HttpsOptions.ServerCertificate); + Assert.Equal(ClientCertificateMode.RequireCertificate, opt.HttpsOptions.ClientCertificateMode); + Assert.False(opt.ListenOptions.NoDelay); + }) + .LocalhostEndpoint(5002, opt => + { + ran2 = true; + Assert.False(opt.NoDelay); + }) + .Load(); + + Assert.True(ran1); + Assert.True(ran2); + + Assert.NotNull(serverOptions.ListenOptions[0].ConnectionAdapters.Where(adapter => adapter.IsHttps).SingleOrDefault()); + Assert.Null(serverOptions.ListenOptions[1].ConnectionAdapters.Where(adapter => adapter.IsHttps).SingleOrDefault()); + } + + [Fact] + public void ConfigureEndpointDefaultCanEnableHttps() + { + var serverOptions = CreateServerOptions(); + + serverOptions.ConfigureEndpointDefaults(opt => + { + opt.NoDelay = false; + opt.UseHttps(new X509Certificate2(TestResources.TestCertificatePath, "testPassword")); + }); + + serverOptions.ConfigureHttpsDefaults(opt => + { + opt.ClientCertificateMode = ClientCertificateMode.RequireCertificate; + }); + + var ran1 = false; + var ran2 = false; + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:End1:Url", "https://*:5001"), + }).Build(); + serverOptions.Configure(config) + .Endpoint("End1", opt => + { + ran1 = true; + Assert.True(opt.IsHttps); + Assert.Equal(ClientCertificateMode.RequireCertificate, opt.HttpsOptions.ClientCertificateMode); + Assert.False(opt.ListenOptions.NoDelay); + }) + .LocalhostEndpoint(5002, opt => + { + ran2 = true; + Assert.False(opt.NoDelay); + }) + .Load(); + + Assert.True(ran1); + Assert.True(ran2); + + // You only get Https once per endpoint. + Assert.NotNull(serverOptions.ListenOptions[0].ConnectionAdapters.Where(adapter => adapter.IsHttps).SingleOrDefault()); + Assert.NotNull(serverOptions.ListenOptions[1].ConnectionAdapters.Where(adapter => adapter.IsHttps).SingleOrDefault()); + } + + [Fact] + public void ConfigureEndpointDevelopmentCertificateGetsLoadedWhenPresent() + { + try + { + var serverOptions = CreateServerOptions(); + var certificate = new X509Certificate2(TestResources.GetCertPath("aspnetdevcert.pfx"), "aspnetdevcert", X509KeyStorageFlags.Exportable); + var bytes = certificate.Export(X509ContentType.Pkcs12, "1234"); + var path = GetCertificatePath(); + Directory.CreateDirectory(Path.GetDirectoryName(path)); + File.WriteAllBytes(path, bytes); + + var ran1 = false; + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:End1:Url", "https://*:5001"), + new KeyValuePair("Certificates:Development:Password", "1234"), + }).Build(); + + serverOptions + .Configure(config) + .Endpoint("End1", opt => + { + ran1 = true; + Assert.True(opt.IsHttps); + Assert.Equal(opt.HttpsOptions.ServerCertificate.SerialNumber, certificate.SerialNumber); + }).Load(); + + Assert.True(ran1); + Assert.NotNull(serverOptions.DefaultCertificate); + } + finally + { + if (File.Exists(GetCertificatePath())) + { + File.Delete(GetCertificatePath()); + } + } + } + + [Fact] + public void ConfigureEndpointDevelopmentCertificateGetsIgnoredIfPasswordIsNotCorrect() + { + try + { + var serverOptions = CreateServerOptions(); + var certificate = new X509Certificate2(TestResources.GetCertPath("aspnetdevcert.pfx"), "aspnetdevcert", X509KeyStorageFlags.Exportable); + var bytes = certificate.Export(X509ContentType.Pkcs12, "1234"); + var path = GetCertificatePath(); + Directory.CreateDirectory(Path.GetDirectoryName(path)); + File.WriteAllBytes(path, bytes); + + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Certificates:Development:Password", "12341234"), + }).Build(); + + serverOptions + .Configure(config) + .Load(); + + Assert.Null(serverOptions.DefaultCertificate); + } + finally + { + if (File.Exists(GetCertificatePath())) + { + File.Delete(GetCertificatePath()); + } + } + } + + [Fact] + public void ConfigureEndpointDevelopmentCertificateGetsIgnoredIfPfxFileDoesNotExist() + { + try + { + var serverOptions = CreateServerOptions(); + if (File.Exists(GetCertificatePath())) + { + File.Delete(GetCertificatePath()); + } + + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Certificates:Development:Password", "12341234") + }).Build(); + + serverOptions + .Configure(config) + .Load(); + + Assert.Null(serverOptions.DefaultCertificate); + } + finally + { + if (File.Exists(GetCertificatePath())) + { + File.Delete(GetCertificatePath()); + } + } + } + + private static string GetCertificatePath() + { + var appData = Environment.GetEnvironmentVariable("APPDATA"); + var home = Environment.GetEnvironmentVariable("HOME"); + var basePath = appData != null ? Path.Combine(appData, "ASP.NET", "https") : null; + basePath = basePath ?? (home != null ? Path.Combine(home, ".aspnet", "https") : null); + return Path.Combine(basePath, $"TestApplication.pfx"); + } + } +} diff --git a/src/Servers/Kestrel/Kestrel/test/Microsoft.AspNetCore.Server.Kestrel.Tests.csproj b/src/Servers/Kestrel/Kestrel/test/Microsoft.AspNetCore.Server.Kestrel.Tests.csproj new file mode 100644 index 0000000000..9bb414d321 --- /dev/null +++ b/src/Servers/Kestrel/Kestrel/test/Microsoft.AspNetCore.Server.Kestrel.Tests.csproj @@ -0,0 +1,17 @@ + + + + netcoreapp2.1;net461 + + + + + + + + + + + + + diff --git a/src/Servers/Kestrel/Kestrel/test/WebHostBuilderKestrelExtensionsTests.cs b/src/Servers/Kestrel/Kestrel/test/WebHostBuilderKestrelExtensionsTests.cs new file mode 100644 index 0000000000..7ea16d4c85 --- /dev/null +++ b/src/Servers/Kestrel/Kestrel/test/WebHostBuilderKestrelExtensionsTests.cs @@ -0,0 +1,99 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Tests +{ + public class WebHostBuilderKestrelExtensionsTests + { + [Fact] + public void ApplicationServicesNotNullAfterUseKestrelWithoutOptions() + { + // Arrange + var hostBuilder = new WebHostBuilder() + .UseKestrel() + .Configure(app => { }); + + hostBuilder.ConfigureServices(services => + { + services.Configure(options => + { + // Assert + Assert.NotNull(options.ApplicationServices); + }); + }); + + // Act + hostBuilder.Build(); + } + + [Fact] + public void ApplicationServicesNotNullDuringUseKestrelWithOptions() + { + // Arrange + var hostBuilder = new WebHostBuilder() + .UseKestrel(options => + { + // Assert + Assert.NotNull(options.ApplicationServices); + }) + .Configure(app => { }); + + // Act + hostBuilder.Build(); + } + + [Fact] + public void SocketTransportIsTheDefault() + { + var hostBuilder = new WebHostBuilder() + .UseKestrel() + .Configure(app => { }); + + Assert.IsType(hostBuilder.Build().Services.GetService()); + } + + [Fact] + public void LibuvTransportCanBeManuallySelectedIndependentOfOrder() + { + var hostBuilder = new WebHostBuilder() + .UseKestrel() + .UseLibuv() + .Configure(app => { }); + + Assert.IsType(hostBuilder.Build().Services.GetService()); + + var hostBuilderReversed = new WebHostBuilder() + .UseLibuv() + .UseKestrel() + .Configure(app => { }); + + Assert.IsType(hostBuilderReversed.Build().Services.GetService()); + } + + [Fact] + public void SocketsTransportCanBeManuallySelectedIndependentOfOrder() + { + var hostBuilder = new WebHostBuilder() + .UseKestrel() + .UseSockets() + .Configure(app => { }); + + Assert.IsType(hostBuilder.Build().Services.GetService()); + + var hostBuilderReversed = new WebHostBuilder() + .UseSockets() + .UseKestrel() + .Configure(app => { }); + + Assert.IsType(hostBuilderReversed.Build().Services.GetService()); + } + } +} diff --git a/src/Servers/Kestrel/README.md b/src/Servers/Kestrel/README.md new file mode 100644 index 0000000000..5aaf71c466 --- /dev/null +++ b/src/Servers/Kestrel/README.md @@ -0,0 +1,8 @@ +KestrelHttpServer +================= + +Kestrel is a cross-platform web server for ASP.NET Core. + +## File logging for functional test + +Turn on file logging for Kestrel functional tests by specifying the environment variable ASPNETCORE_TEST_LOG_DIR to the log output directory. diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/FileHandleType.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/FileHandleType.cs new file mode 100644 index 0000000000..bb70e4ec34 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/FileHandleType.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + /// + /// Enumerates the types. + /// + public enum FileHandleType + { + Auto, + Tcp, + Pipe + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/IApplicationTransportFeature.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/IApplicationTransportFeature.cs new file mode 100644 index 0000000000..490cb7f065 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/IApplicationTransportFeature.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.IO.Pipelines; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + public interface IApplicationTransportFeature + { + IDuplexPipe Application { get; set; } + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/IBytesWrittenFeature.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/IBytesWrittenFeature.cs new file mode 100644 index 0000000000..e4bf998f37 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/IBytesWrittenFeature.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + public interface IBytesWrittenFeature + { + long TotalBytesWritten { get; } + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/IConnectionDispatcher.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/IConnectionDispatcher.cs new file mode 100644 index 0000000000..dbcd7411aa --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/IConnectionDispatcher.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + public interface IConnectionDispatcher + { + void OnConnection(TransportConnection connection); + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/IEndPointInformation.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/IEndPointInformation.cs new file mode 100644 index 0000000000..1b7abfa497 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/IEndPointInformation.cs @@ -0,0 +1,46 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Net; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + public interface IEndPointInformation + { + /// + /// The type of interface being described: either an , Unix domain socket path, or a file descriptor. + /// + ListenType Type { get; } + + // IPEndPoint is mutable so port 0 can be updated to the bound port. + /// + /// The to bind to. + /// Only set if is . + /// + IPEndPoint IPEndPoint { get; set; } + + /// + /// The absolute path to a Unix domain socket to bind to. + /// Only set if is . + /// + string SocketPath { get; } + + /// + /// A file descriptor for the socket to open. + /// Only set if is . + /// + ulong FileHandle { get; } + + // HandleType is mutable so it can be re-specified later. + /// + /// The type of file descriptor being used. + /// Only set if is . + /// + FileHandleType HandleType { get; set; } + + /// + /// Set to false to enable Nagle's algorithm for all connections. + /// + bool NoDelay { get; } + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/ITransport.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/ITransport.cs new file mode 100644 index 0000000000..5a6dc0c20c --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/ITransport.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + public interface ITransport + { + // Can only be called once per ITransport + Task BindAsync(); + Task UnbindAsync(); + Task StopAsync(); + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/ITransportFactory.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/ITransportFactory.cs new file mode 100644 index 0000000000..4037467e87 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/ITransportFactory.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + public interface ITransportFactory + { + ITransport Create(IEndPointInformation endPointInformation, IConnectionDispatcher dispatcher); + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/ITransportSchedulerFeature.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/ITransportSchedulerFeature.cs new file mode 100644 index 0000000000..be113bbe10 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/ITransportSchedulerFeature.cs @@ -0,0 +1,16 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.IO.Pipelines; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + public interface ITransportSchedulerFeature + { + PipeScheduler InputWriterScheduler { get; } + + PipeScheduler OutputReaderScheduler { get; } + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/KestrelMemoryPool.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/KestrelMemoryPool.cs new file mode 100644 index 0000000000..f53c15d543 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/KestrelMemoryPool.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + public static class KestrelMemoryPool + { + public static MemoryPool Create() => new SlabMemoryPool(); + + public static readonly int MinimumSegmentSize = 4096; + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/ListenType.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/ListenType.cs new file mode 100644 index 0000000000..3616f1967e --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/ListenType.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + /// + /// Enumerates the types. + /// + public enum ListenType + { + IPEndPoint, + SocketPath, + FileHandle + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/MemoryPoolBlock.Debug.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/MemoryPoolBlock.Debug.cs new file mode 100644 index 0000000000..6f815716d4 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/MemoryPoolBlock.Debug.cs @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +#if DEBUG + +using System.Threading; +using System.Diagnostics; + +namespace System.Buffers +{ + /// + /// Block tracking object used by the byte buffer memory pool. A slab is a large allocation which is divided into smaller blocks. The + /// individual blocks are then treated as independent array segments. + /// + internal sealed class MemoryPoolBlock : MemoryManager + { + private readonly int _offset; + private readonly int _length; + + private int _pinCount; + + /// + /// This object cannot be instantiated outside of the static Create method + /// + internal MemoryPoolBlock(SlabMemoryPool pool, MemoryPoolSlab slab, int offset, int length) + { + _offset = offset; + _length = length; + + Pool = pool; + Slab = slab; + } + + /// + /// Back-reference to the memory pool which this block was allocated from. It may only be returned to this pool. + /// + public SlabMemoryPool Pool { get; } + + /// + /// Back-reference to the slab from which this block was taken, or null if it is one-time-use memory. + /// + public MemoryPoolSlab Slab { get; } + + public override Memory Memory + { + get + { + if (!Slab.IsActive) ThrowHelper.ThrowObjectDisposedException(ExceptionArgument.MemoryPoolBlock); + + return CreateMemory(_length); + } + } + + +#if BLOCK_LEASE_TRACKING + public bool IsLeased { get; set; } + public string Leaser { get; set; } +#endif + + ~MemoryPoolBlock() + { + if (Slab != null && Slab.IsActive) + { + Debug.Assert(false, $"{Environment.NewLine}{Environment.NewLine}*** Block being garbage collected instead of returned to pool" + +#if BLOCK_LEASE_TRACKING + $": {Leaser}" + +#endif + $" ***{ Environment.NewLine}"); + + // Need to make a new object because this one is being finalized + Pool.Return(new MemoryPoolBlock(Pool, Slab, _offset, _length)); + } + } + + protected override void Dispose(bool disposing) + { + if (!Slab.IsActive) ThrowHelper.ThrowObjectDisposedException(ExceptionArgument.MemoryPoolBlock); + + if (Volatile.Read(ref _pinCount) > 0) + { + ThrowHelper.ThrowInvalidOperationException_ReturningPinnedBlock(); + } + + Pool.Return(this); + } + + public override Span GetSpan() => new Span(Slab.Array, _offset, _length); + + public override MemoryHandle Pin(int byteOffset = 0) + { + if (!Slab.IsActive) ThrowHelper.ThrowObjectDisposedException(ExceptionArgument.MemoryPoolBlock); + if (byteOffset < 0 || byteOffset > _length) ThrowHelper.ThrowArgumentOutOfRangeException(_length, byteOffset); + + Interlocked.Increment(ref _pinCount); + unsafe + { + return new MemoryHandle((Slab.NativePointer + _offset + byteOffset).ToPointer(), default, this); + } + } + + protected override bool TryGetArray(out ArraySegment segment) + { + segment = new ArraySegment(Slab.Array, _offset, _length); + return true; + } + + public override void Unpin() + { + if (Interlocked.Decrement(ref _pinCount) < 0) + { + ThrowHelper.ThrowInvalidOperationException_ReferenceCountZero(); + } + } + + public void Lease() + { +#if BLOCK_LEASE_TRACKING + Leaser = Environment.StackTrace; + IsLeased = true; +#endif + } + } +} + +#endif \ No newline at end of file diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/MemoryPoolBlock.Release.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/MemoryPoolBlock.Release.cs new file mode 100644 index 0000000000..93784df05e --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/MemoryPoolBlock.Release.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +#if RELEASE + +using System.Runtime.InteropServices; + +namespace System.Buffers +{ + /// + /// Block tracking object used by the byte buffer memory pool. A slab is a large allocation which is divided into smaller blocks. The + /// individual blocks are then treated as independent array segments. + /// + internal sealed class MemoryPoolBlock : IMemoryOwner + { + private readonly int _offset; + private readonly int _length; + + /// + /// This object cannot be instantiated outside of the static Create method + /// + internal MemoryPoolBlock(SlabMemoryPool pool, MemoryPoolSlab slab, int offset, int length) + { + _offset = offset; + _length = length; + + Pool = pool; + Slab = slab; + + Memory = MemoryMarshal.CreateFromPinnedArray(slab.Array, _offset, _length); + } + + /// + /// Back-reference to the memory pool which this block was allocated from. It may only be returned to this pool. + /// + public SlabMemoryPool Pool { get; } + + /// + /// Back-reference to the slab from which this block was taken, or null if it is one-time-use memory. + /// + public MemoryPoolSlab Slab { get; } + + public Memory Memory { get; } + + ~MemoryPoolBlock() + { + if (Slab != null && Slab.IsActive) + { + // Need to make a new object because this one is being finalized + Pool.Return(new MemoryPoolBlock(Pool, Slab, _offset, _length)); + } + } + + public void Dispose() + { + Pool.Return(this); + } + + public void Lease() + { + } + } +} + +#endif diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/MemoryPoolSlab.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/MemoryPoolSlab.cs new file mode 100644 index 0000000000..4b15f88233 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/MemoryPoolSlab.cs @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Runtime.InteropServices; + +namespace System.Buffers +{ + /// + /// Slab tracking object used by the byte buffer memory pool. A slab is a large allocation which is divided into smaller blocks. The + /// individual blocks are then treated as independant array segments. + /// + internal class MemoryPoolSlab : IDisposable + { + /// + /// This handle pins the managed array in memory until the slab is disposed. This prevents it from being + /// relocated and enables any subsections of the array to be used as native memory pointers to P/Invoked API calls. + /// + private readonly GCHandle _gcHandle; + private readonly IntPtr _nativePointer; + private byte[] _data; + + private bool _isActive; + private bool _disposedValue; + + public MemoryPoolSlab(byte[] data) + { + _data = data; + _gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned); + _nativePointer = _gcHandle.AddrOfPinnedObject(); + _isActive = true; + } + + /// + /// True as long as the blocks from this slab are to be considered returnable to the pool. In order to shrink the + /// memory pool size an entire slab must be removed. That is done by (1) setting IsActive to false and removing the + /// slab from the pool's _slabs collection, (2) as each block currently in use is Return()ed to the pool it will + /// be allowed to be garbage collected rather than re-pooled, and (3) when all block tracking objects are garbage + /// collected and the slab is no longer references the slab will be garbage collected and the memory unpinned will + /// be unpinned by the slab's Dispose. + /// + public bool IsActive => _isActive; + + public IntPtr NativePointer => _nativePointer; + + public byte[] Array => _data; + + public int Length => _data.Length; + + public static MemoryPoolSlab Create(int length) + { + // allocate and pin requested memory length + var array = new byte[length]; + + // allocate and return slab tracking object + return new MemoryPoolSlab(array); + } + + protected void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + // N/A: dispose managed state (managed objects). + } + + _isActive = false; + + if (_gcHandle.IsAllocated) + { + _gcHandle.Free(); + } + + // set large fields to null. + _data = null; + + _disposedValue = true; + } + } + + // override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. + ~MemoryPoolSlab() + { + // Do not change this code. Put cleanup code in Dispose(bool disposing) above. + Dispose(false); + } + + // This code added to correctly implement the disposable pattern. + public void Dispose() + { + // Do not change this code. Put cleanup code in Dispose(bool disposing) above. + Dispose(true); + // uncomment the following line if the finalizer is overridden above. + GC.SuppressFinalize(this); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/SchedulingMode.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/SchedulingMode.cs new file mode 100644 index 0000000000..881006087c --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/SchedulingMode.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + public enum SchedulingMode + { + Default, + ThreadPool, + Inline + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/SlabMemoryPool.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/SlabMemoryPool.cs new file mode 100644 index 0000000000..e36ec70adb --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/SlabMemoryPool.cs @@ -0,0 +1,193 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Collections.Concurrent; +using System.Diagnostics; + +namespace System.Buffers +{ + /// + /// Used to allocate and distribute re-usable blocks of memory. + /// + internal class SlabMemoryPool : MemoryPool + { + /// + /// The size of a block. 4096 is chosen because most operating systems use 4k pages. + /// + private const int _blockSize = 4096; + + /// + /// Allocating 32 contiguous blocks per slab makes the slab size 128k. This is larger than the 85k size which will place the memory + /// in the large object heap. This means the GC will not try to relocate this array, so the fact it remains pinned does not negatively + /// affect memory management's compactification. + /// + private const int _blockCount = 32; + + /// + /// Max allocation block size for pooled blocks, + /// larger values can be leased but they will be disposed after use rather than returned to the pool. + /// + public override int MaxBufferSize { get; } = _blockSize; + + /// + /// 4096 * 32 gives you a slabLength of 128k contiguous bytes allocated per slab + /// + private static readonly int _slabLength = _blockSize * _blockCount; + + /// + /// Thread-safe collection of blocks which are currently in the pool. A slab will pre-allocate all of the block tracking objects + /// and add them to this collection. When memory is requested it is taken from here first, and when it is returned it is re-added. + /// + private readonly ConcurrentQueue _blocks = new ConcurrentQueue(); + + /// + /// Thread-safe collection of slabs which have been allocated by this pool. As long as a slab is in this collection and slab.IsActive, + /// the blocks will be added to _blocks when returned. + /// + private readonly ConcurrentStack _slabs = new ConcurrentStack(); + + /// + /// This is part of implementing the IDisposable pattern. + /// + private bool _disposedValue = false; // To detect redundant calls + + /// + /// This default value passed in to Rent to use the default value for the pool. + /// + private const int AnySize = -1; + + public override IMemoryOwner Rent(int size = AnySize) + { + if (size == AnySize) size = _blockSize; + else if (size > _blockSize) + { + ThrowHelper.ThrowArgumentOutOfRangeException_BufferRequestTooLarge(_blockSize); + } + + var block = Lease(); + return block; + } + + /// + /// Called to take a block from the pool. + /// + /// The block that is reserved for the called. It must be passed to Return when it is no longer being used. + private MemoryPoolBlock Lease() + { + Debug.Assert(!_disposedValue, "Block being leased from disposed pool!"); + + if (_blocks.TryDequeue(out MemoryPoolBlock block)) + { + // block successfully taken from the stack - return it + + block.Lease(); + return block; + } + // no blocks available - grow the pool + block = AllocateSlab(); + block.Lease(); + return block; + } + + /// + /// Internal method called when a block is requested and the pool is empty. It allocates one additional slab, creates all of the + /// block tracking objects, and adds them all to the pool. + /// + private MemoryPoolBlock AllocateSlab() + { + var slab = MemoryPoolSlab.Create(_slabLength); + _slabs.Push(slab); + + var basePtr = slab.NativePointer; + // Page align the blocks + var firstOffset = (int)((((ulong)basePtr + (uint)_blockSize - 1) & ~((uint)_blockSize - 1)) - (ulong)basePtr); + // Ensure page aligned + Debug.Assert((((ulong)basePtr + (uint)firstOffset) & (uint)(_blockSize - 1)) == 0); + + var blockAllocationLength = ((_slabLength - firstOffset) & ~(_blockSize - 1)); + var offset = firstOffset; + for (; + offset + _blockSize < blockAllocationLength; + offset += _blockSize) + { + var block = new MemoryPoolBlock( + this, + slab, + offset, + _blockSize); +#if BLOCK_LEASE_TRACKING + block.IsLeased = true; +#endif + Return(block); + } + + Debug.Assert(offset + _blockSize - firstOffset == blockAllocationLength); + // return last block rather than adding to pool + var newBlock = new MemoryPoolBlock( + this, + slab, + offset, + _blockSize); + + return newBlock; + } + + /// + /// Called to return a block to the pool. Once Return has been called the memory no longer belongs to the caller, and + /// Very Bad Things will happen if the memory is read of modified subsequently. If a caller fails to call Return and the + /// block tracking object is garbage collected, the block tracking object's finalizer will automatically re-create and return + /// a new tracking object into the pool. This will only happen if there is a bug in the server, however it is necessary to avoid + /// leaving "dead zones" in the slab due to lost block tracking objects. + /// + /// The block to return. It must have been acquired by calling Lease on the same memory pool instance. + internal void Return(MemoryPoolBlock block) + { +#if BLOCK_LEASE_TRACKING + Debug.Assert(block.Pool == this, "Returned block was not leased from this pool"); + Debug.Assert(block.IsLeased, $"Block being returned to pool twice: {block.Leaser}{Environment.NewLine}"); + block.IsLeased = false; +#endif + + if (block.Slab != null && block.Slab.IsActive) + { + _blocks.Enqueue(block); + } + else + { + GC.SuppressFinalize(block); + } + } + + protected override void Dispose(bool disposing) + { + if (!_disposedValue) + { + _disposedValue = true; +#if DEBUG && !INNER_LOOP + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); +#endif + if (disposing) + { + while (_slabs.TryPop(out MemoryPoolSlab slab)) + { + // dispose managed state (managed objects). + slab.Dispose(); + } + } + + // Discard blocks in pool + while (_blocks.TryDequeue(out MemoryPoolBlock block)) + { + GC.SuppressFinalize(block); + } + + // N/A: free unmanaged resources (unmanaged objects) and override a finalizer below. + + // N/A: set large fields to null. + + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/TransportConnection.Features.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/TransportConnection.Features.cs new file mode 100644 index 0000000000..4b65762f30 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/TransportConnection.Features.cs @@ -0,0 +1,405 @@ +using System; +using System.Buffers; +using System.Collections; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Net; +using System.Threading; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + public partial class TransportConnection : IFeatureCollection, + IHttpConnectionFeature, + IConnectionIdFeature, + IConnectionTransportFeature, + IConnectionItemsFeature, + IMemoryPoolFeature, + IApplicationTransportFeature, + ITransportSchedulerFeature, + IConnectionLifetimeFeature, + IBytesWrittenFeature + { + private static readonly Type IHttpConnectionFeatureType = typeof(IHttpConnectionFeature); + private static readonly Type IConnectionIdFeatureType = typeof(IConnectionIdFeature); + private static readonly Type IConnectionTransportFeatureType = typeof(IConnectionTransportFeature); + private static readonly Type IConnectionItemsFeatureType = typeof(IConnectionItemsFeature); + private static readonly Type IMemoryPoolFeatureType = typeof(IMemoryPoolFeature); + private static readonly Type IApplicationTransportFeatureType = typeof(IApplicationTransportFeature); + private static readonly Type ITransportSchedulerFeatureType = typeof(ITransportSchedulerFeature); + private static readonly Type IConnectionLifetimeFeatureType = typeof(IConnectionLifetimeFeature); + private static readonly Type IBytesWrittenFeatureType = typeof(IBytesWrittenFeature); + + private object _currentIHttpConnectionFeature; + private object _currentIConnectionIdFeature; + private object _currentIConnectionTransportFeature; + private object _currentIConnectionItemsFeature; + private object _currentIMemoryPoolFeature; + private object _currentIApplicationTransportFeature; + private object _currentITransportSchedulerFeature; + private object _currentIConnectionLifetimeFeature; + private object _currentIBytesWrittenFeature; + + private int _featureRevision; + + private List> MaybeExtra; + + private object ExtraFeatureGet(Type key) + { + if (MaybeExtra == null) + { + return null; + } + for (var i = 0; i < MaybeExtra.Count; i++) + { + var kv = MaybeExtra[i]; + if (kv.Key == key) + { + return kv.Value; + } + } + return null; + } + + private void ExtraFeatureSet(Type key, object value) + { + if (MaybeExtra == null) + { + MaybeExtra = new List>(2); + } + + for (var i = 0; i < MaybeExtra.Count; i++) + { + if (MaybeExtra[i].Key == key) + { + MaybeExtra[i] = new KeyValuePair(key, value); + return; + } + } + MaybeExtra.Add(new KeyValuePair(key, value)); + } + + bool IFeatureCollection.IsReadOnly => false; + + int IFeatureCollection.Revision => _featureRevision; + + string IHttpConnectionFeature.ConnectionId + { + get => ConnectionId; + set => ConnectionId = value; + } + + IPAddress IHttpConnectionFeature.RemoteIpAddress + { + get => RemoteAddress; + set => RemoteAddress = value; + } + + IPAddress IHttpConnectionFeature.LocalIpAddress + { + get => LocalAddress; + set => LocalAddress = value; + } + + int IHttpConnectionFeature.RemotePort + { + get => RemotePort; + set => RemotePort = value; + } + + int IHttpConnectionFeature.LocalPort + { + get => LocalPort; + set => LocalPort = value; + } + + MemoryPool IMemoryPoolFeature.MemoryPool => MemoryPool; + + IDuplexPipe IConnectionTransportFeature.Transport + { + get => Transport; + set => Transport = value; + } + + IDuplexPipe IApplicationTransportFeature.Application + { + get => Application; + set => Application = value; + } + + IDictionary IConnectionItemsFeature.Items + { + get => Items; + set => Items = value; + } + + CancellationToken IConnectionLifetimeFeature.ConnectionClosed + { + get => ConnectionClosed; + set => ConnectionClosed = value; + } + + void IConnectionLifetimeFeature.Abort() => Abort(); + + long IBytesWrittenFeature.TotalBytesWritten => TotalBytesWritten; + + PipeScheduler ITransportSchedulerFeature.InputWriterScheduler => InputWriterScheduler; + PipeScheduler ITransportSchedulerFeature.OutputReaderScheduler => OutputReaderScheduler; + + object IFeatureCollection.this[Type key] + { + get + { + if (key == IHttpConnectionFeatureType) + { + return _currentIHttpConnectionFeature; + } + + if (key == IConnectionIdFeatureType) + { + return _currentIConnectionIdFeature; + } + + if (key == IConnectionTransportFeatureType) + { + return _currentIConnectionTransportFeature; + } + + if (key == IConnectionItemsFeatureType) + { + return _currentIConnectionItemsFeature; + } + + if (key == IMemoryPoolFeatureType) + { + return _currentIMemoryPoolFeature; + } + + if (key == IApplicationTransportFeatureType) + { + return _currentIApplicationTransportFeature; + } + + if (key == ITransportSchedulerFeatureType) + { + return _currentITransportSchedulerFeature; + } + + if (key == IConnectionLifetimeFeatureType) + { + return _currentIConnectionLifetimeFeature; + } + + if (key == IBytesWrittenFeatureType) + { + return _currentIBytesWrittenFeature; + } + + if (MaybeExtra != null) + { + return ExtraFeatureGet(key); + } + + return null; + } + set + { + _featureRevision++; + + if (key == IHttpConnectionFeatureType) + { + _currentIHttpConnectionFeature = value; + } + else if (key == IConnectionIdFeatureType) + { + _currentIConnectionIdFeature = value; + } + else if (key == IConnectionTransportFeatureType) + { + _currentIConnectionTransportFeature = value; + } + else if (key == IConnectionItemsFeatureType) + { + _currentIConnectionItemsFeature = value; + } + else if (key == IMemoryPoolFeatureType) + { + _currentIMemoryPoolFeature = value; + } + else if (key == IApplicationTransportFeatureType) + { + _currentIApplicationTransportFeature = value; + } + else if (key == ITransportSchedulerFeatureType) + { + _currentITransportSchedulerFeature = value; + } + else if (key == IConnectionLifetimeFeatureType) + { + _currentIConnectionLifetimeFeature = value; + } + else if (key == IBytesWrittenFeatureType) + { + _currentIBytesWrittenFeature = value; + } + else + { + ExtraFeatureSet(key, value); + } + } + } + + TFeature IFeatureCollection.Get() + { + if (typeof(TFeature) == typeof(IHttpConnectionFeature)) + { + return (TFeature)_currentIHttpConnectionFeature; + } + else if (typeof(TFeature) == typeof(IConnectionIdFeature)) + { + return (TFeature)_currentIConnectionIdFeature; + } + else if (typeof(TFeature) == typeof(IConnectionTransportFeature)) + { + return (TFeature)_currentIConnectionTransportFeature; + } + else if (typeof(TFeature) == typeof(IConnectionItemsFeature)) + { + return (TFeature)_currentIConnectionItemsFeature; + } + else if (typeof(TFeature) == typeof(IMemoryPoolFeature)) + { + return (TFeature)_currentIMemoryPoolFeature; + } + else if (typeof(TFeature) == typeof(IApplicationTransportFeature)) + { + return (TFeature)_currentIApplicationTransportFeature; + } + else if (typeof(TFeature) == typeof(ITransportSchedulerFeature)) + { + return (TFeature)_currentITransportSchedulerFeature; + } + else if (typeof(TFeature) == typeof(IConnectionLifetimeFeature)) + { + return (TFeature)_currentIConnectionLifetimeFeature; + } + else if (typeof(TFeature) == typeof(IBytesWrittenFeature)) + { + return (TFeature)_currentIBytesWrittenFeature; + } + else if (MaybeExtra != null) + { + return (TFeature)ExtraFeatureGet(typeof(TFeature)); + } + + return default; + } + + void IFeatureCollection.Set(TFeature instance) + { + _featureRevision++; + + if (typeof(TFeature) == typeof(IHttpConnectionFeature)) + { + _currentIHttpConnectionFeature = instance; + } + else if (typeof(TFeature) == typeof(IConnectionIdFeature)) + { + _currentIConnectionIdFeature = instance; + } + else if (typeof(TFeature) == typeof(IConnectionTransportFeature)) + { + _currentIConnectionTransportFeature = instance; + } + else if (typeof(TFeature) == typeof(IConnectionItemsFeature)) + { + _currentIConnectionItemsFeature = instance; + } + else if (typeof(TFeature) == typeof(IMemoryPoolFeature)) + { + _currentIMemoryPoolFeature = instance; + } + else if (typeof(TFeature) == typeof(IApplicationTransportFeature)) + { + _currentIApplicationTransportFeature = instance; + } + else if (typeof(TFeature) == typeof(ITransportSchedulerFeature)) + { + _currentITransportSchedulerFeature = instance; + } + else if (typeof(TFeature) == typeof(IConnectionLifetimeFeature)) + { + _currentIConnectionLifetimeFeature = instance; + } + else if (typeof(TFeature) == typeof(IBytesWrittenFeature)) + { + _currentIBytesWrittenFeature = instance; + } + else + { + ExtraFeatureSet(typeof(TFeature), instance); + } + } + + IEnumerator> IEnumerable>.GetEnumerator() => FastEnumerable().GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => FastEnumerable().GetEnumerator(); + + private IEnumerable> FastEnumerable() + { + if (_currentIHttpConnectionFeature != null) + { + yield return new KeyValuePair(IHttpConnectionFeatureType, _currentIHttpConnectionFeature); + } + + if (_currentIConnectionIdFeature != null) + { + yield return new KeyValuePair(IConnectionIdFeatureType, _currentIConnectionIdFeature); + } + + if (_currentIConnectionTransportFeature != null) + { + yield return new KeyValuePair(IConnectionTransportFeatureType, _currentIConnectionTransportFeature); + } + + if (_currentIConnectionItemsFeature != null) + { + yield return new KeyValuePair(IConnectionItemsFeatureType, _currentIConnectionItemsFeature); + } + + if (_currentIMemoryPoolFeature != null) + { + yield return new KeyValuePair(IMemoryPoolFeatureType, _currentIMemoryPoolFeature); + } + + if (_currentIApplicationTransportFeature != null) + { + yield return new KeyValuePair(IApplicationTransportFeatureType, _currentIApplicationTransportFeature); + } + + if (_currentITransportSchedulerFeature != null) + { + yield return new KeyValuePair(ITransportSchedulerFeatureType, _currentITransportSchedulerFeature); + } + + if (_currentIConnectionLifetimeFeature != null) + { + yield return new KeyValuePair(IConnectionLifetimeFeatureType, _currentIConnectionLifetimeFeature); + } + + if (_currentIBytesWrittenFeature != null) + { + yield return new KeyValuePair(IBytesWrittenFeatureType, _currentIBytesWrittenFeature); + } + + if (MaybeExtra != null) + { + foreach (var item in MaybeExtra) + { + yield return item; + } + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Internal/TransportConnection.cs b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/TransportConnection.cs new file mode 100644 index 0000000000..d96ceea0ec --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Internal/TransportConnection.cs @@ -0,0 +1,74 @@ +using System.Buffers; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Net; +using System.Threading; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal +{ + public partial class TransportConnection : ConnectionContext + { + private IDictionary _items; + + public TransportConnection() + { + _currentIConnectionIdFeature = this; + _currentIConnectionTransportFeature = this; + _currentIHttpConnectionFeature = this; + _currentIConnectionItemsFeature = this; + _currentIApplicationTransportFeature = this; + _currentIMemoryPoolFeature = this; + _currentITransportSchedulerFeature = this; + _currentIConnectionLifetimeFeature = this; + _currentIBytesWrittenFeature = this; + } + + public IPAddress RemoteAddress { get; set; } + public int RemotePort { get; set; } + public IPAddress LocalAddress { get; set; } + public int LocalPort { get; set; } + + public override string ConnectionId { get; set; } + + public override IFeatureCollection Features => this; + + public virtual MemoryPool MemoryPool { get; } + public virtual PipeScheduler InputWriterScheduler { get; } + public virtual PipeScheduler OutputReaderScheduler { get; } + public virtual long TotalBytesWritten { get; } + + public override IDuplexPipe Transport { get; set; } + public IDuplexPipe Application { get; set; } + + public override IDictionary Items + { + get + { + // Lazily allocate connection metadata + return _items ?? (_items = new ConnectionItems()); + } + set + { + _items = value; + } + } + + public PipeWriter Input => Application.Output; + public PipeReader Output => Application.Input; + + public CancellationToken ConnectionClosed { get; set; } + + // DO NOT remove this override to ConnectionContext.Abort. Doing so would cause + // any TransportConnection that does not override Abort or calls base.Abort + // to stack overflow when IConnectionLifetimeFeature.Abort() is called. + // That said, all derived types should override this method should override + // this implementation of Abort because canceling pending output reads is not + // sufficient to abort the connection if there is backpressure. + public override void Abort(ConnectionAbortedException abortReason) + { + Output.CancelPendingRead(); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.csproj b/src/Servers/Kestrel/Transport.Abstractions/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.csproj new file mode 100644 index 0000000000..fe131c9dd6 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.csproj @@ -0,0 +1,21 @@ + + + + Transport abstractions for the ASP.NET Core Kestrel cross-platform web server. + netstandard2.0 + true + aspnetcore;kestrel + CS1570;CS1571;CS1572;CS1573;CS1574;CS1591;$(NoWarn) + true + false + + + + + + + + + + + diff --git a/src/Servers/Kestrel/Transport.Abstractions/src/baseline.netcore.json b/src/Servers/Kestrel/Transport.Abstractions/src/baseline.netcore.json new file mode 100644 index 0000000000..15710d1a41 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Abstractions/src/baseline.netcore.json @@ -0,0 +1,4 @@ +{ + "AssemblyIdentity": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions, Version=2.0.2.0, Culture=neutral, PublicKeyToken=adb9793829ddae60", + "Types": [] +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/IAsyncDisposable.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/IAsyncDisposable.cs new file mode 100644 index 0000000000..8c98c2127c --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/IAsyncDisposable.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + interface IAsyncDisposable + { + Task DisposeAsync(); + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/ILibuvTrace.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/ILibuvTrace.cs new file mode 100644 index 0000000000..46512ba482 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/ILibuvTrace.cs @@ -0,0 +1,31 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + public interface ILibuvTrace : ILogger + { + void ConnectionRead(string connectionId, int count); + + void ConnectionReadFin(string connectionId); + + void ConnectionWriteFin(string connectionId); + + void ConnectionWroteFin(string connectionId, int status); + + void ConnectionWrite(string connectionId, int count); + + void ConnectionWriteCallback(string connectionId, int status); + + void ConnectionError(string connectionId, Exception ex); + + void ConnectionReset(string connectionId); + + void ConnectionPause(string connectionId); + + void ConnectionResume(string connectionId); + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvAwaitable.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvAwaitable.cs new file mode 100644 index 0000000000..b4d40dea29 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvAwaitable.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + public class LibuvAwaitable : ICriticalNotifyCompletion where TRequest : UvRequest + { + private readonly static Action _callbackCompleted = () => { }; + + private Action _callback; + + private UvException _exception; + + private int _status; + + public static readonly Action Callback = (req, status, error, state) => + { + var awaitable = (LibuvAwaitable)state; + + awaitable._exception = error; + awaitable._status = status; + + var continuation = Interlocked.Exchange(ref awaitable._callback, _callbackCompleted); + + continuation?.Invoke(); + }; + + public LibuvAwaitable GetAwaiter() => this; + public bool IsCompleted => ReferenceEquals(_callback, _callbackCompleted); + + public UvWriteResult GetResult() + { + Debug.Assert(_callback == _callbackCompleted); + + var exception = _exception; + var status = _status; + + // Reset the awaitable state + _exception = null; + _status = 0; + _callback = null; + + return new UvWriteResult(status, exception); + } + + public void OnCompleted(Action continuation) + { + // There should never be a race between IsCompleted and OnCompleted since both operations + // should always be on the libuv thread + + if (ReferenceEquals(_callback, _callbackCompleted) || + ReferenceEquals(Interlocked.CompareExchange(ref _callback, continuation, null), _callbackCompleted)) + { + Debug.Fail($"{typeof(LibuvAwaitable)}.{nameof(OnCompleted)} raced with {nameof(IsCompleted)}, running callback inline."); + + // Just run it inline + continuation(); + } + } + + public void UnsafeOnCompleted(Action continuation) + { + OnCompleted(continuation); + } + } + + public struct UvWriteResult + { + public int Status { get; } + public UvException Error { get; } + + public UvWriteResult(int status, UvException error) + { + Status = status; + Error = error; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvConnection.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvConnection.cs new file mode 100644 index 0000000000..f0aa50df8b --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvConnection.cs @@ -0,0 +1,253 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO; +using System.IO.Pipelines; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + public partial class LibuvConnection : TransportConnection + { + private static readonly int MinAllocBufferSize = KestrelMemoryPool.MinimumSegmentSize / 2; + + private static readonly Action _readCallback = + (handle, status, state) => ReadCallback(handle, status, state); + + private static readonly Func _allocCallback = + (handle, suggestedsize, state) => AllocCallback(handle, suggestedsize, state); + + private readonly UvStreamHandle _socket; + private readonly CancellationTokenSource _connectionClosedTokenSource = new CancellationTokenSource(); + + private volatile ConnectionAbortedException _abortReason; + + private MemoryHandle _bufferHandle; + + public LibuvConnection(UvStreamHandle socket, ILibuvTrace log, LibuvThread thread, IPEndPoint remoteEndPoint, IPEndPoint localEndPoint) + { + _socket = socket; + + RemoteAddress = remoteEndPoint?.Address; + RemotePort = remoteEndPoint?.Port ?? 0; + + LocalAddress = localEndPoint?.Address; + LocalPort = localEndPoint?.Port ?? 0; + + ConnectionClosed = _connectionClosedTokenSource.Token; + Log = log; + Thread = thread; + } + + public LibuvOutputConsumer OutputConsumer { get; set; } + private ILibuvTrace Log { get; } + private LibuvThread Thread { get; } + public override MemoryPool MemoryPool => Thread.MemoryPool; + public override PipeScheduler InputWriterScheduler => Thread; + public override PipeScheduler OutputReaderScheduler => Thread; + + public override long TotalBytesWritten => OutputConsumer?.TotalBytesWritten ?? 0; + + public async Task Start() + { + try + { + OutputConsumer = new LibuvOutputConsumer(Output, Thread, _socket, ConnectionId, Log); + + StartReading(); + + Exception inputError = null; + Exception outputError = null; + + try + { + // This *must* happen after socket.ReadStart + // The socket output consumer is the only thing that can close the connection. If the + // output pipe is already closed by the time we start then it's fine since, it'll close gracefully afterwards. + await OutputConsumer.WriteOutputAsync(); + } + catch (UvException ex) + { + // The connection reset/error has already been logged by LibuvOutputConsumer + if (ex.StatusCode == LibuvConstants.ECANCELED) + { + // Connection was aborted. + } + else if (LibuvConstants.IsConnectionReset(ex.StatusCode)) + { + // Don't cause writes to throw for connection resets. + inputError = new ConnectionResetException(ex.Message, ex); + } + else + { + inputError = ex; + outputError = ex; + } + } + finally + { + // Now, complete the input so that no more reads can happen + Input.Complete(inputError ?? _abortReason ?? new ConnectionAbortedException()); + Output.Complete(outputError); + + // Make sure it isn't possible for a paused read to resume reading after calling uv_close + // on the stream handle + Input.CancelPendingFlush(); + + // Send a FIN + Log.ConnectionWriteFin(ConnectionId); + + // We're done with the socket now + _socket.Dispose(); + ThreadPool.QueueUserWorkItem(state => ((LibuvConnection)state).CancelConnectionClosedToken(), this); + } + } + catch (Exception e) + { + Log.LogCritical(0, e, $"{nameof(LibuvConnection)}.{nameof(Start)}() {ConnectionId}"); + } + } + + public override void Abort(ConnectionAbortedException abortReason) + { + _abortReason = abortReason; + Output.CancelPendingRead(); + + // This cancels any pending I/O. + Thread.Post(s => s.Dispose(), _socket); + } + + // Called on Libuv thread + private static LibuvFunctions.uv_buf_t AllocCallback(UvStreamHandle handle, int suggestedSize, object state) + { + return ((LibuvConnection)state).OnAlloc(handle, suggestedSize); + } + + private unsafe LibuvFunctions.uv_buf_t OnAlloc(UvStreamHandle handle, int suggestedSize) + { + var currentWritableBuffer = Input.GetMemory(MinAllocBufferSize); + _bufferHandle = currentWritableBuffer.Pin(); + + return handle.Libuv.buf_init((IntPtr)_bufferHandle.Pointer, currentWritableBuffer.Length); + } + + private static void ReadCallback(UvStreamHandle handle, int status, object state) + { + ((LibuvConnection)state).OnRead(handle, status); + } + + private void OnRead(UvStreamHandle handle, int status) + { + // Cleanup state from last OnAlloc. This is safe even if OnAlloc wasn't called. + _bufferHandle.Dispose(); + if (status == 0) + { + // EAGAIN/EWOULDBLOCK so just return the buffer. + // http://docs.libuv.org/en/v1.x/stream.html#c.uv_read_cb + } + else if (status > 0) + { + Log.ConnectionRead(ConnectionId, status); + + Input.Advance(status); + var flushTask = Input.FlushAsync(); + + if (!flushTask.IsCompleted) + { + // We wrote too many bytes to the reader, so pause reading and resume when + // we hit the low water mark. + _ = ApplyBackpressureAsync(flushTask); + } + } + else + { + // Given a negative status, it's possible that OnAlloc wasn't called. + _socket.ReadStop(); + + Exception error = null; + + if (status == LibuvConstants.EOF) + { + Log.ConnectionReadFin(ConnectionId); + } + else + { + handle.Libuv.Check(status, out var uvError); + error = LogAndWrapReadError(uvError); + } + + // Complete after aborting the connection + Input.Complete(error); + } + } + + private async Task ApplyBackpressureAsync(ValueTask flushTask) + { + Log.ConnectionPause(ConnectionId); + _socket.ReadStop(); + + var result = await flushTask; + + // If the reader isn't complete or cancelled then resume reading + if (!result.IsCompleted && !result.IsCanceled) + { + Log.ConnectionResume(ConnectionId); + StartReading(); + } + } + + private void StartReading() + { + try + { + _socket.ReadStart(_allocCallback, _readCallback, this); + } + catch (UvException ex) + { + // ReadStart() can throw a UvException in some cases (e.g. socket is no longer connected). + // This should be treated the same as OnRead() seeing a negative status. + Input.Complete(LogAndWrapReadError(ex)); + } + } + + private Exception LogAndWrapReadError(UvException uvError) + { + if (uvError.StatusCode == LibuvConstants.ECANCELED) + { + // The operation was canceled by the server not the client. No need for additional logs. + return new ConnectionAbortedException(uvError.Message, uvError); + } + else if (LibuvConstants.IsConnectionReset(uvError.StatusCode)) + { + // Log connection resets at a lower (Debug) level. + Log.ConnectionReset(ConnectionId); + return new ConnectionResetException(uvError.Message, uvError); + } + else + { + Log.ConnectionError(ConnectionId, uvError); + return new IOException(uvError.Message, uvError); + } + } + + private void CancelConnectionClosedToken() + { + try + { + _connectionClosedTokenSource.Cancel(); + } + catch (Exception ex) + { + Log.LogError(0, ex, $"Unexpected exception in {nameof(LibuvConnection)}.{nameof(CancelConnectionClosedToken)}."); + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvConstants.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvConstants.cs new file mode 100644 index 0000000000..d1657b63b8 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvConstants.cs @@ -0,0 +1,143 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + internal static class LibuvConstants + { + public const int ListenBacklog = 128; + + public const int EOF = -4095; + public static readonly int? ECONNRESET = GetECONNRESET(); + public static readonly int? EADDRINUSE = GetEADDRINUSE(); + public static readonly int? ENOTSUP = GetENOTSUP(); + public static readonly int? EPIPE = GetEPIPE(); + public static readonly int? ECANCELED = GetECANCELED(); + public static readonly int? ENOTCONN = GetENOTCONN(); + public static readonly int? EINVAL = GetEINVAL(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsConnectionReset(int errno) + { + return errno == ECONNRESET || errno == EPIPE || errno == ENOTCONN | errno == EINVAL; + } + + private static int? GetECONNRESET() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return -4077; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return -104; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return -54; + } + return null; + } + + private static int? GetEPIPE() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return -4047; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return -32; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return -32; + } + return null; + } + + private static int? GetENOTCONN() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return -4053; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return -107; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return -57; + } + return null; + } + + private static int? GetEINVAL() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return -4071; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return -22; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return -22; + } + return null; + } + + private static int? GetEADDRINUSE() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return -4091; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return -98; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return -48; + } + return null; + } + + private static int? GetENOTSUP() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return -95; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return -45; + } + return null; + } + + private static int? GetECANCELED() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return -4081; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return -125; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return -89; + } + return null; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvOutputConsumer.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvOutputConsumer.cs new file mode 100644 index 0000000000..6049245537 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvOutputConsumer.cs @@ -0,0 +1,129 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + public class LibuvOutputConsumer + { + private readonly LibuvThread _thread; + private readonly UvStreamHandle _socket; + private readonly string _connectionId; + private readonly ILibuvTrace _log; + private readonly PipeReader _pipe; + + private long _totalBytesWritten; + + public LibuvOutputConsumer( + PipeReader pipe, + LibuvThread thread, + UvStreamHandle socket, + string connectionId, + ILibuvTrace log) + { + _pipe = pipe; + _thread = thread; + _socket = socket; + _connectionId = connectionId; + _log = log; + } + + public long TotalBytesWritten => Interlocked.Read(ref _totalBytesWritten); + + public async Task WriteOutputAsync() + { + var pool = _thread.WriteReqPool; + + while (true) + { + var result = await _pipe.ReadAsync(); + + var buffer = result.Buffer; + var consumed = buffer.End; + + try + { + if (result.IsCanceled) + { + break; + } + + if (!buffer.IsEmpty) + { + var writeReq = pool.Allocate(); + + try + { + if (_socket.IsClosed) + { + break; + } + + var writeResult = await writeReq.WriteAsync(_socket, buffer); + + // This is not interlocked because there could be a concurrent writer. + // Instead it's to prevent read tearing on 32-bit systems. + Interlocked.Add(ref _totalBytesWritten, buffer.Length); + + LogWriteInfo(writeResult.Status, writeResult.Error); + + if (writeResult.Error != null) + { + consumed = buffer.Start; + throw writeResult.Error; + } + } + finally + { + // Make sure we return the writeReq to the pool + pool.Return(writeReq); + + // Null out writeReq so it doesn't get caught by CheckUvReqLeaks. + // It is rooted by a TestSink scope through Pipe continuations in + // ResponseTests.HttpsConnectionClosedWhenResponseDoesNotSatisfyMinimumDataRate + writeReq = null; + } + } + + if (result.IsCompleted) + { + break; + } + } + finally + { + _pipe.AdvanceTo(consumed); + } + } + } + + private void LogWriteInfo(int status, Exception error) + { + if (error == null) + { + _log.ConnectionWriteCallback(_connectionId, status); + } + else + { + // Log connection resets at a lower (Debug) level. + if (status == LibuvConstants.ECANCELED) + { + // Connection was aborted. + } + else if (LibuvConstants.IsConnectionReset(status)) + { + _log.ConnectionReset(_connectionId); + } + else + { + _log.ConnectionError(_connectionId, error); + } + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvThread.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvThread.cs new file mode 100644 index 0000000000..f8145ae22d --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvThread.cs @@ -0,0 +1,452 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO.Pipelines; +using System.Runtime.ExceptionServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + public class LibuvThread : PipeScheduler + { + // maximum times the work queues swapped and are processed in a single pass + // as completing a task may immediately have write data to put on the network + // otherwise it needs to wait till the next pass of the libuv loop + private readonly int _maxLoops = 8; + + private readonly LibuvTransport _transport; + private readonly IApplicationLifetime _appLifetime; + private readonly Thread _thread; + private readonly TaskCompletionSource _threadTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly UvLoopHandle _loop; + private readonly UvAsyncHandle _post; + private Queue _workAdding = new Queue(1024); + private Queue _workRunning = new Queue(1024); + private Queue _closeHandleAdding = new Queue(256); + private Queue _closeHandleRunning = new Queue(256); + private readonly object _workSync = new object(); + private readonly object _closeHandleSync = new object(); + private readonly object _startSync = new object(); + private bool _stopImmediate = false; + private bool _initCompleted = false; + private ExceptionDispatchInfo _closeError; + private readonly ILibuvTrace _log; + + public LibuvThread(LibuvTransport transport) + { + _transport = transport; + _appLifetime = transport.AppLifetime; + _log = transport.Log; + _loop = new UvLoopHandle(_log); + _post = new UvAsyncHandle(_log); + + _thread = new Thread(ThreadStart); +#if !INNER_LOOP + _thread.Name = nameof(LibuvThread); +#endif + +#if !DEBUG + // Mark the thread as being as unimportant to keeping the process alive. + // Don't do this for debug builds, so we know if the thread isn't terminating. + _thread.IsBackground = true; +#endif + QueueCloseHandle = PostCloseHandle; + QueueCloseAsyncHandle = EnqueueCloseHandle; + MemoryPool = KestrelMemoryPool.Create(); + WriteReqPool = new WriteReqPool(this, _log); + } + + // For testing + public LibuvThread(LibuvTransport transport, int maxLoops) + : this(transport) + { + _maxLoops = maxLoops; + } + + public UvLoopHandle Loop { get { return _loop; } } + + public MemoryPool MemoryPool { get; } + + public WriteReqPool WriteReqPool { get; } + +#if DEBUG + public List Requests { get; } = new List(); +#endif + + public ExceptionDispatchInfo FatalError { get { return _closeError; } } + + public Action, IntPtr> QueueCloseHandle { get; } + + private Action, IntPtr> QueueCloseAsyncHandle { get; } + + public Task StartAsync() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _thread.Start(tcs); + return tcs.Task; + } + + public async Task StopAsync(TimeSpan timeout) + { + lock (_startSync) + { + if (!_initCompleted) + { + return; + } + } + + Debug.Assert(!_threadTcs.Task.IsCompleted, "The loop thread was completed before calling uv_unref on the post handle."); + + var stepTimeout = TimeSpan.FromTicks(timeout.Ticks / 3); + + try + { + Post(t => t.AllowStop()); + if (!await WaitAsync(_threadTcs.Task, stepTimeout).ConfigureAwait(false)) + { + Post(t => t.OnStopRude()); + if (!await WaitAsync(_threadTcs.Task, stepTimeout).ConfigureAwait(false)) + { + Post(t => t.OnStopImmediate()); + if (!await WaitAsync(_threadTcs.Task, stepTimeout).ConfigureAwait(false)) + { + _log.LogCritical($"{nameof(LibuvThread)}.{nameof(StopAsync)} failed to terminate libuv thread."); + } + } + } + } + catch (ObjectDisposedException) + { + if (!await WaitAsync(_threadTcs.Task, stepTimeout).ConfigureAwait(false)) + { + _log.LogCritical($"{nameof(LibuvThread)}.{nameof(StopAsync)} failed to terminate libuv thread."); + } + } + + _closeError?.Throw(); + } + +#if DEBUG && !INNER_LOOP + private void CheckUvReqLeaks() + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + // Detect leaks in UvRequest objects + foreach (var request in Requests) + { + Debug.Assert(request.Target == null, $"{request.Target?.GetType()} object is still alive."); + } + } +#endif + + private void AllowStop() + { + _post.Unreference(); + } + + private void OnStopRude() + { + Walk(ptr => + { + var handle = UvMemory.FromIntPtr(ptr); + if (handle != _post) + { + // handle can be null because UvMemory.FromIntPtr looks up a weak reference + handle?.Dispose(); + } + }); + } + + private void OnStopImmediate() + { + _stopImmediate = true; + _loop.Stop(); + } + + public void Post(Action callback, T state) + { + // Handle is closed to don't bother scheduling anything + if (_post.IsClosed) + { + return; + } + + var work = new Work + { + CallbackAdapter = CallbackAdapter.PostCallbackAdapter, + Callback = callback, + // TODO: This boxes + State = state + }; + + lock (_workSync) + { + _workAdding.Enqueue(work); + } + + try + { + _post.Send(); + } + catch (ObjectDisposedException) + { + // There's an inherent race here where we're in the middle of shutdown + } + } + + private void Post(Action callback) + { + Post(callback, this); + } + + public Task PostAsync(Action callback, T state) + { + // Handle is closed to don't bother scheduling anything + if (_post.IsClosed) + { + return Task.CompletedTask; + } + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var work = new Work + { + CallbackAdapter = CallbackAdapter.PostAsyncCallbackAdapter, + Callback = callback, + State = state, + Completion = tcs + }; + + lock (_workSync) + { + _workAdding.Enqueue(work); + } + + try + { + _post.Send(); + } + catch (ObjectDisposedException) + { + // There's an inherent race here where we're in the middle of shutdown + } + return tcs.Task; + } + + public void Walk(Action callback) + { + Walk((ptr, arg) => callback(ptr), IntPtr.Zero); + } + + private void Walk(LibuvFunctions.uv_walk_cb callback, IntPtr arg) + { + _transport.Libuv.walk( + _loop, + callback, + arg + ); + } + + private void PostCloseHandle(Action callback, IntPtr handle) + { + EnqueueCloseHandle(callback, handle); + _post.Send(); + } + + private void EnqueueCloseHandle(Action callback, IntPtr handle) + { + var closeHandle = new CloseHandle { Callback = callback, Handle = handle }; + lock (_closeHandleSync) + { + _closeHandleAdding.Enqueue(closeHandle); + } + } + + private void ThreadStart(object parameter) + { + lock (_startSync) + { + var tcs = (TaskCompletionSource)parameter; + try + { + _loop.Init(_transport.Libuv); + _post.Init(_loop, OnPost, EnqueueCloseHandle); + _initCompleted = true; + tcs.SetResult(0); + } + catch (Exception ex) + { + tcs.SetException(ex); + return; + } + } + + try + { + _loop.Run(); + if (_stopImmediate) + { + // thread-abort form of exit, resources will be leaked + return; + } + + // run the loop one more time to delete the open handles + _post.Reference(); + _post.Dispose(); + + // We need this walk because we call ReadStop on accepted connections when there's back pressure + // Calling ReadStop makes the handle as in-active which means the loop can + // end while there's still valid handles around. This makes loop.Dispose throw + // with an EBUSY. To avoid that, we walk all of the handles and dispose them. + Walk(ptr => + { + var handle = UvMemory.FromIntPtr(ptr); + // handle can be null because UvMemory.FromIntPtr looks up a weak reference + handle?.Dispose(); + }); + + // Ensure the Dispose operations complete in the event loop. + _loop.Run(); + + _loop.Dispose(); + } + catch (Exception ex) + { + _closeError = ExceptionDispatchInfo.Capture(ex); + // Request shutdown so we can rethrow this exception + // in Stop which should be observable. + _appLifetime.StopApplication(); + } + finally + { + MemoryPool.Dispose(); + WriteReqPool.Dispose(); + _threadTcs.SetResult(null); + +#if DEBUG && !INNER_LOOP + // Check for handle leaks after disposing everything + CheckUvReqLeaks(); +#endif + } + } + + private void OnPost() + { + var loopsRemaining = _maxLoops; + bool wasWork; + do + { + wasWork = DoPostWork(); + wasWork = DoPostCloseHandle() || wasWork; + loopsRemaining--; + } while (wasWork && loopsRemaining > 0); + } + + private bool DoPostWork() + { + Queue queue; + lock (_workSync) + { + queue = _workAdding; + _workAdding = _workRunning; + _workRunning = queue; + } + + bool wasWork = queue.Count > 0; + + while (queue.Count != 0) + { + var work = queue.Dequeue(); + try + { + work.CallbackAdapter(work.Callback, work.State); + work.Completion?.TrySetResult(null); + } + catch (Exception ex) + { + if (work.Completion != null) + { + work.Completion.TrySetException(ex); + } + else + { + _log.LogError(0, ex, $"{nameof(LibuvThread)}.{nameof(DoPostWork)}"); + throw; + } + } + } + + return wasWork; + } + + private bool DoPostCloseHandle() + { + Queue queue; + lock (_closeHandleSync) + { + queue = _closeHandleAdding; + _closeHandleAdding = _closeHandleRunning; + _closeHandleRunning = queue; + } + + bool wasWork = queue.Count > 0; + + while (queue.Count != 0) + { + var closeHandle = queue.Dequeue(); + try + { + closeHandle.Callback(closeHandle.Handle); + } + catch (Exception ex) + { + _log.LogError(0, ex, $"{nameof(LibuvThread)}.{nameof(DoPostCloseHandle)}"); + throw; + } + } + + return wasWork; + } + + private static async Task WaitAsync(Task task, TimeSpan timeout) + { + return await Task.WhenAny(task, Task.Delay(timeout)).ConfigureAwait(false) == task; + } + + public override void Schedule(Action action, object state) + { + Post(action, state); + } + + private struct Work + { + public Action CallbackAdapter; + public object Callback; + public object State; + public TaskCompletionSource Completion; + } + + private struct CloseHandle + { + public Action Callback; + public IntPtr Handle; + } + + private class CallbackAdapter + { + public static readonly Action PostCallbackAdapter = (callback, state) => ((Action)callback).Invoke((T)state); + public static readonly Action PostAsyncCallbackAdapter = (callback, state) => ((Action)callback).Invoke((T)state); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvTrace.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvTrace.cs new file mode 100644 index 0000000000..96f3338821 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvTrace.cs @@ -0,0 +1,105 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + public class LibuvTrace : ILibuvTrace + { + // ConnectionRead: Reserved: 3 + + private static readonly Action _connectionPause = + LoggerMessage.Define(LogLevel.Debug, new EventId(4, nameof(ConnectionPause)), @"Connection id ""{ConnectionId}"" paused."); + + private static readonly Action _connectionResume = + LoggerMessage.Define(LogLevel.Debug, new EventId(5, nameof(ConnectionResume)), @"Connection id ""{ConnectionId}"" resumed."); + + private static readonly Action _connectionReadFin = + LoggerMessage.Define(LogLevel.Debug, new EventId(6, nameof(ConnectionReadFin)), @"Connection id ""{ConnectionId}"" received FIN."); + + private static readonly Action _connectionWriteFin = + LoggerMessage.Define(LogLevel.Debug, new EventId(7, nameof(ConnectionWriteFin)), @"Connection id ""{ConnectionId}"" sending FIN."); + + private static readonly Action _connectionWroteFin = + LoggerMessage.Define(LogLevel.Debug, new EventId(8, nameof(ConnectionWroteFin)), @"Connection id ""{ConnectionId}"" sent FIN with status ""{Status}""."); + + // ConnectionWrite: Reserved: 11 + + // ConnectionWriteCallback: Reserved: 12 + + private static readonly Action _connectionError = + LoggerMessage.Define(LogLevel.Information, 14, @"Connection id ""{ConnectionId}"" communication error."); + + private static readonly Action _connectionReset = + LoggerMessage.Define(LogLevel.Debug, 19, @"Connection id ""{ConnectionId}"" reset."); + + private readonly ILogger _logger; + + public LibuvTrace(ILogger logger) + { + _logger = logger; + } + + public void ConnectionRead(string connectionId, int count) + { + // Don't log for now since this could be *too* verbose. + // Reserved: Event ID 3 + } + + public void ConnectionReadFin(string connectionId) + { + _connectionReadFin(_logger, connectionId, null); + } + + public void ConnectionWriteFin(string connectionId) + { + _connectionWriteFin(_logger, connectionId, null); + } + + public void ConnectionWroteFin(string connectionId, int status) + { + _connectionWroteFin(_logger, connectionId, status, null); + } + + public void ConnectionWrite(string connectionId, int count) + { + // Don't log for now since this could be *too* verbose. + // Reserved: Event ID 11 + } + + public void ConnectionWriteCallback(string connectionId, int status) + { + // Don't log for now since this could be *too* verbose. + // Reserved: Event ID 12 + } + + public void ConnectionError(string connectionId, Exception ex) + { + _connectionError(_logger, connectionId, ex); + } + + public void ConnectionReset(string connectionId) + { + _connectionReset(_logger, connectionId, null); + } + + public void ConnectionPause(string connectionId) + { + _connectionPause(_logger, connectionId, null); + } + + public void ConnectionResume(string connectionId) + { + _connectionResume(_logger, connectionId, null); + } + + public IDisposable BeginScope(TState state) => _logger.BeginScope(state); + + public bool IsEnabled(LogLevel logLevel) => _logger.IsEnabled(logLevel); + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) + => _logger.Log(logLevel, eventId, state, exception, formatter); + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvTransport.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvTransport.cs new file mode 100644 index 0000000000..293745f38b --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvTransport.cs @@ -0,0 +1,139 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + public class LibuvTransport : ITransport + { + private readonly IEndPointInformation _endPointInformation; + + private readonly List _listeners = new List(); + + public LibuvTransport(LibuvTransportContext context, IEndPointInformation endPointInformation) + : this(new LibuvFunctions(), context, endPointInformation) + { } + + // For testing + public LibuvTransport(LibuvFunctions uv, LibuvTransportContext context, IEndPointInformation endPointInformation) + { + Libuv = uv; + TransportContext = context; + + _endPointInformation = endPointInformation; + } + + public LibuvFunctions Libuv { get; } + public LibuvTransportContext TransportContext { get; } + public List Threads { get; } = new List(); + + public IApplicationLifetime AppLifetime => TransportContext.AppLifetime; + public ILibuvTrace Log => TransportContext.Log; + public LibuvTransportOptions TransportOptions => TransportContext.Options; + + public async Task StopAsync() + { + try + { + await Task.WhenAll(Threads.Select(thread => thread.StopAsync(TimeSpan.FromSeconds(5))).ToArray()) + .ConfigureAwait(false); + } + catch (AggregateException aggEx) + { + // An uncaught exception was likely thrown from the libuv event loop. + // The original error that crashed one loop may have caused secondary errors in others. + // Make sure that the stack trace of the original error is logged. + foreach (var ex in aggEx.InnerExceptions) + { + Log.LogCritical("Failed to gracefully close Kestrel.", ex); + } + + throw; + } + + Threads.Clear(); +#if DEBUG && !INNER_LOOP + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); +#endif + } + + public async Task BindAsync() + { + // TODO: Move thread management to LibuvTransportFactory + // TODO: Split endpoint management from thread management + for (var index = 0; index < TransportOptions.ThreadCount; index++) + { + Threads.Add(new LibuvThread(this)); + } + + foreach (var thread in Threads) + { + await thread.StartAsync().ConfigureAwait(false); + } + + try + { + if (TransportOptions.ThreadCount == 1) + { + var listener = new Listener(TransportContext); + _listeners.Add(listener); + await listener.StartAsync(_endPointInformation, Threads[0]).ConfigureAwait(false); + } + else + { + var pipeName = (Libuv.IsWindows ? @"\\.\pipe\kestrel_" : "/tmp/kestrel_") + Guid.NewGuid().ToString("n"); + var pipeMessage = Guid.NewGuid().ToByteArray(); + + var listenerPrimary = new ListenerPrimary(TransportContext); + _listeners.Add(listenerPrimary); + await listenerPrimary.StartAsync(pipeName, pipeMessage, _endPointInformation, Threads[0]).ConfigureAwait(false); + + foreach (var thread in Threads.Skip(1)) + { + var listenerSecondary = new ListenerSecondary(TransportContext); + _listeners.Add(listenerSecondary); + await listenerSecondary.StartAsync(pipeName, pipeMessage, _endPointInformation, thread).ConfigureAwait(false); + } + } + } + catch (UvException ex) when (ex.StatusCode == LibuvConstants.EADDRINUSE) + { + await UnbindAsync().ConfigureAwait(false); + throw new AddressInUseException(ex.Message, ex); + } + catch + { + await UnbindAsync().ConfigureAwait(false); + throw; + } + } + + public async Task UnbindAsync() + { + var disposeTasks = _listeners.Select(listener => listener.DisposeAsync()).ToArray(); + + if (!await WaitAsync(Task.WhenAll(disposeTasks), TimeSpan.FromSeconds(5)).ConfigureAwait(false)) + { + Log.LogError(0, null, "Disposing listeners failed"); + } + + _listeners.Clear(); + } + + private static async Task WaitAsync(Task task, TimeSpan timeout) + { + return await Task.WhenAny(task, Task.Delay(timeout)).ConfigureAwait(false) == task; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvTransportContext.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvTransportContext.cs new file mode 100644 index 0000000000..37074dc968 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvTransportContext.cs @@ -0,0 +1,19 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + public class LibuvTransportContext + { + public LibuvTransportOptions Options { get; set; } + + public IApplicationLifetime AppLifetime { get; set; } + + public ILibuvTrace Log { get; set; } + + public IConnectionDispatcher ConnectionDispatcher { get; set; } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvTransportFactory.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvTransportFactory.cs new file mode 100644 index 0000000000..c2ff9e8a56 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/LibuvTransportFactory.cs @@ -0,0 +1,77 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + public class LibuvTransportFactory : ITransportFactory + { + private readonly LibuvTransportContext _baseTransportContext; + + public LibuvTransportFactory( + IOptions options, + IApplicationLifetime applicationLifetime, + ILoggerFactory loggerFactory) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + if (applicationLifetime == null) + { + throw new ArgumentNullException(nameof(applicationLifetime)); + } + if (loggerFactory == null) + { + throw new ArgumentNullException(nameof(loggerFactory)); + } + + var logger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv"); + var trace = new LibuvTrace(logger); + + var threadCount = options.Value.ThreadCount; + + if (threadCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(threadCount), + threadCount, + "ThreadCount must be positive."); + } + + if (!LibuvConstants.ECONNRESET.HasValue) + { + trace.LogWarning("Unable to determine ECONNRESET value on this platform."); + } + + if (!LibuvConstants.EADDRINUSE.HasValue) + { + trace.LogWarning("Unable to determine EADDRINUSE value on this platform."); + } + + _baseTransportContext = new LibuvTransportContext + { + Options = options.Value, + AppLifetime = applicationLifetime, + Log = trace, + }; + } + + public ITransport Create(IEndPointInformation endPointInformation, IConnectionDispatcher dispatcher) + { + var transportContext = new LibuvTransportContext + { + Options = _baseTransportContext.Options, + AppLifetime = _baseTransportContext.AppLifetime, + Log = _baseTransportContext.Log, + ConnectionDispatcher = dispatcher + }; + + return new LibuvTransport(transportContext, endPointInformation); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Listener.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Listener.cs new file mode 100644 index 0000000000..f4b3520854 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Listener.cs @@ -0,0 +1,212 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + /// + /// Base class for listeners in Kestrel. Listens for incoming connections + /// + public class Listener : ListenerContext, IAsyncDisposable + { + private bool _closed; + + public Listener(LibuvTransportContext transportContext) : base(transportContext) + { + } + + protected UvStreamHandle ListenSocket { get; private set; } + + public ILibuvTrace Log => TransportContext.Log; + + public Task StartAsync( + IEndPointInformation endPointInformation, + LibuvThread thread) + { + EndPointInformation = endPointInformation; + Thread = thread; + + return Thread.PostAsync(listener => + { + listener.ListenSocket = listener.CreateListenSocket(); + listener.ListenSocket.Listen(LibuvConstants.ListenBacklog, ConnectionCallback, listener); + }, this); + } + + /// + /// Creates the socket used to listen for incoming connections + /// + private UvStreamHandle CreateListenSocket() + { + switch (EndPointInformation.Type) + { + case ListenType.IPEndPoint: + return ListenTcp(useFileHandle: false); + case ListenType.SocketPath: + return ListenPipe(useFileHandle: false); + case ListenType.FileHandle: + return ListenHandle(); + default: + throw new NotSupportedException(); + } + } + + private UvTcpHandle ListenTcp(bool useFileHandle) + { + var socket = new UvTcpHandle(Log); + + try + { + socket.Init(Thread.Loop, Thread.QueueCloseHandle); + socket.NoDelay(EndPointInformation.NoDelay); + + if (!useFileHandle) + { + socket.Bind(EndPointInformation.IPEndPoint); + + // If requested port was "0", replace with assigned dynamic port. + EndPointInformation.IPEndPoint = socket.GetSockIPEndPoint(); + } + else + { + socket.Open((IntPtr)EndPointInformation.FileHandle); + } + } + catch + { + socket.Dispose(); + throw; + } + + return socket; + } + + private UvPipeHandle ListenPipe(bool useFileHandle) + { + var pipe = new UvPipeHandle(Log); + + try + { + pipe.Init(Thread.Loop, Thread.QueueCloseHandle, false); + + if (!useFileHandle) + { + pipe.Bind(EndPointInformation.SocketPath); + } + else + { + pipe.Open((IntPtr)EndPointInformation.FileHandle); + } + } + catch + { + pipe.Dispose(); + throw; + } + + return pipe; + } + + private UvStreamHandle ListenHandle() + { + switch (EndPointInformation.HandleType) + { + case FileHandleType.Auto: + break; + case FileHandleType.Tcp: + return ListenTcp(useFileHandle: true); + case FileHandleType.Pipe: + return ListenPipe(useFileHandle: true); + default: + throw new NotSupportedException(); + } + + UvStreamHandle handle; + try + { + handle = ListenTcp(useFileHandle: true); + EndPointInformation.HandleType = FileHandleType.Tcp; + return handle; + } + catch (UvException exception) when (exception.StatusCode == LibuvConstants.ENOTSUP) + { + Log.LogDebug(0, exception, "Listener.ListenHandle"); + } + + handle = ListenPipe(useFileHandle: true); + EndPointInformation.HandleType = FileHandleType.Pipe; + return handle; + } + + private static void ConnectionCallback(UvStreamHandle stream, int status, UvException error, object state) + { + var listener = (Listener)state; + + if (error != null) + { + listener.Log.LogError(0, error, "Listener.ConnectionCallback"); + } + else if (!listener._closed) + { + listener.OnConnection(stream, status); + } + } + + /// + /// Handles an incoming connection + /// + /// Socket being used to listen on + /// Connection status + private void OnConnection(UvStreamHandle listenSocket, int status) + { + UvStreamHandle acceptSocket = null; + + try + { + acceptSocket = CreateAcceptSocket(); + listenSocket.Accept(acceptSocket); + DispatchConnection(acceptSocket); + } + catch (UvException ex) when (LibuvConstants.IsConnectionReset(ex.StatusCode)) + { + Log.ConnectionReset("(null)"); + acceptSocket?.Dispose(); + } + catch (UvException ex) + { + Log.LogError(0, ex, "Listener.OnConnection"); + acceptSocket?.Dispose(); + } + } + + protected virtual void DispatchConnection(UvStreamHandle socket) + { + HandleConnectionAsync(socket); + } + + public virtual async Task DisposeAsync() + { + // Ensure the event loop is still running. + // If the event loop isn't running and we try to wait on this Post + // to complete, then LibuvTransport will never be disposed and + // the exception that stopped the event loop will never be surfaced. + if (Thread.FatalError == null && ListenSocket != null) + { + await Thread.PostAsync(listener => + { + listener.ListenSocket.Dispose(); + + listener._closed = true; + + }, this).ConfigureAwait(false); + } + + ListenSocket = null; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/ListenerContext.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/ListenerContext.cs new file mode 100644 index 0000000000..8399c90dfa --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/ListenerContext.cs @@ -0,0 +1,126 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + public class ListenerContext + { + public ListenerContext(LibuvTransportContext transportContext) + { + TransportContext = transportContext; + } + + public LibuvTransportContext TransportContext { get; set; } + + public IEndPointInformation EndPointInformation { get; set; } + + public LibuvThread Thread { get; set; } + + /// + /// Creates a socket which can be used to accept an incoming connection. + /// + protected UvStreamHandle CreateAcceptSocket() + { + switch (EndPointInformation.Type) + { + case ListenType.IPEndPoint: + return AcceptTcp(); + case ListenType.SocketPath: + return AcceptPipe(); + case ListenType.FileHandle: + return AcceptHandle(); + default: + throw new InvalidOperationException(); + } + } + + protected void HandleConnectionAsync(UvStreamHandle socket) + { + try + { + IPEndPoint remoteEndPoint = null; + IPEndPoint localEndPoint = null; + + if (socket is UvTcpHandle tcpHandle) + { + try + { + remoteEndPoint = tcpHandle.GetPeerIPEndPoint(); + localEndPoint = tcpHandle.GetSockIPEndPoint(); + } + catch (UvException ex) when (LibuvConstants.IsConnectionReset(ex.StatusCode)) + { + TransportContext.Log.ConnectionReset("(null)"); + socket.Dispose(); + return; + } + } + + var connection = new LibuvConnection(socket, TransportContext.Log, Thread, remoteEndPoint, localEndPoint); + TransportContext.ConnectionDispatcher.OnConnection(connection); + _ = connection.Start(); + } + catch (Exception ex) + { + TransportContext.Log.LogCritical(ex, $"Unexpected exception in {nameof(ListenerContext)}.{nameof(HandleConnectionAsync)}."); + } + } + + private UvTcpHandle AcceptTcp() + { + var socket = new UvTcpHandle(TransportContext.Log); + + try + { + socket.Init(Thread.Loop, Thread.QueueCloseHandle); + socket.NoDelay(EndPointInformation.NoDelay); + } + catch + { + socket.Dispose(); + throw; + } + + return socket; + } + + private UvPipeHandle AcceptPipe() + { + var pipe = new UvPipeHandle(TransportContext.Log); + + try + { + pipe.Init(Thread.Loop, Thread.QueueCloseHandle); + } + catch + { + pipe.Dispose(); + throw; + } + + return pipe; + } + + private UvStreamHandle AcceptHandle() + { + switch (EndPointInformation.HandleType) + { + case FileHandleType.Auto: + throw new InvalidOperationException("Cannot accept on a non-specific file handle, listen should be performed first."); + case FileHandleType.Tcp: + return AcceptTcp(); + case FileHandleType.Pipe: + return AcceptPipe(); + default: + throw new NotSupportedException(); + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/ListenerPrimary.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/ListenerPrimary.cs new file mode 100644 index 0000000000..1218136148 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/ListenerPrimary.cs @@ -0,0 +1,270 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + /// + /// A primary listener waits for incoming connections on a specified socket. Incoming + /// connections may be passed to a secondary listener to handle. + /// + public class ListenerPrimary : Listener + { + // The list of pipes that can be dispatched to (where we've confirmed the _pipeMessage) + private readonly List _dispatchPipes = new List(); + // The list of pipes we've created but may not be part of _dispatchPipes + private readonly List _createdPipes = new List(); + private int _dispatchIndex; + private string _pipeName; + private byte[] _pipeMessage; + private IntPtr _fileCompletionInfoPtr; + private bool _tryDetachFromIOCP = PlatformApis.IsWindows; + + // this message is passed to write2 because it must be non-zero-length, + // but it has no other functional significance + private readonly ArraySegment> _dummyMessage = new ArraySegment>(new[] { new ArraySegment(new byte[] { 1, 2, 3, 4 }) }); + + public ListenerPrimary(LibuvTransportContext transportContext) : base(transportContext) + { + } + + /// + /// For testing purposes. + /// + public int UvPipeCount => _dispatchPipes.Count; + + private UvPipeHandle ListenPipe { get; set; } + + public async Task StartAsync( + string pipeName, + byte[] pipeMessage, + IEndPointInformation endPointInformation, + LibuvThread thread) + { + _pipeName = pipeName; + _pipeMessage = pipeMessage; + + if (_fileCompletionInfoPtr == IntPtr.Zero) + { + var fileCompletionInfo = new FILE_COMPLETION_INFORMATION() { Key = IntPtr.Zero, Port = IntPtr.Zero }; + _fileCompletionInfoPtr = Marshal.AllocHGlobal(Marshal.SizeOf(fileCompletionInfo)); + Marshal.StructureToPtr(fileCompletionInfo, _fileCompletionInfoPtr, false); + } + + await StartAsync(endPointInformation, thread).ConfigureAwait(false); + + await Thread.PostAsync(listener => listener.PostCallback(), this).ConfigureAwait(false); + } + + private void PostCallback() + { + ListenPipe = new UvPipeHandle(Log); + ListenPipe.Init(Thread.Loop, Thread.QueueCloseHandle, false); + ListenPipe.Bind(_pipeName); + ListenPipe.Listen(LibuvConstants.ListenBacklog, + (pipe, status, error, state) => ((ListenerPrimary)state).OnListenPipe(pipe, status, error), this); + } + + private void OnListenPipe(UvStreamHandle pipe, int status, UvException error) + { + if (status < 0) + { + return; + } + + var dispatchPipe = new UvPipeHandle(Log); + // Add to the list of created pipes for disposal tracking + _createdPipes.Add(dispatchPipe); + + try + { + dispatchPipe.Init(Thread.Loop, Thread.QueueCloseHandle, true); + pipe.Accept(dispatchPipe); + + // Ensure client sends "Kestrel" before adding pipe to _dispatchPipes. + var readContext = new PipeReadContext(this); + dispatchPipe.ReadStart( + (handle, status2, state) => ((PipeReadContext)state).AllocCallback(handle, status2), + (handle, status2, state) => ((PipeReadContext)state).ReadCallback(handle, status2), + readContext); + } + catch (UvException ex) + { + dispatchPipe.Dispose(); + Log.LogError(0, ex, "ListenerPrimary.OnListenPipe"); + } + } + + protected override void DispatchConnection(UvStreamHandle socket) + { + var index = _dispatchIndex++ % (_dispatchPipes.Count + 1); + if (index == _dispatchPipes.Count) + { + base.DispatchConnection(socket); + } + else + { + DetachFromIOCP(socket); + var dispatchPipe = _dispatchPipes[index]; + var write = new UvWriteReq(Log); + try + { + write.Init(Thread); + write.Write2( + dispatchPipe, + _dummyMessage, + socket, + (write2, status, error, state) => + { + write2.Dispose(); + ((UvStreamHandle)state).Dispose(); + }, + socket); + } + catch (UvException) + { + write.Dispose(); + throw; + } + } + } + + private void DetachFromIOCP(UvHandle handle) + { + if (!_tryDetachFromIOCP) + { + return; + } + + // https://msdn.microsoft.com/en-us/library/windows/hardware/ff728840(v=vs.85).aspx + const int FileReplaceCompletionInformation = 61; + // https://msdn.microsoft.com/en-us/library/cc704588.aspx + const uint STATUS_INVALID_INFO_CLASS = 0xC0000003; + + var statusBlock = new IO_STATUS_BLOCK(); + var socket = IntPtr.Zero; + Thread.Loop.Libuv.uv_fileno(handle, ref socket); + + if (NtSetInformationFile(socket, out statusBlock, _fileCompletionInfoPtr, + (uint)Marshal.SizeOf(), FileReplaceCompletionInformation) == STATUS_INVALID_INFO_CLASS) + { + // Replacing IOCP information is only supported on Windows 8.1 or newer + _tryDetachFromIOCP = false; + } + } + + private struct IO_STATUS_BLOCK + { + uint status; + ulong information; + } + + private struct FILE_COMPLETION_INFORMATION + { + public IntPtr Port; + public IntPtr Key; + } + + [DllImport("NtDll.dll")] + private static extern uint NtSetInformationFile(IntPtr FileHandle, + out IO_STATUS_BLOCK IoStatusBlock, IntPtr FileInformation, uint Length, + int FileInformationClass); + + public override async Task DisposeAsync() + { + // Call base first so the ListenSocket gets closed and doesn't + // try to dispatch connections to closed pipes. + await base.DisposeAsync().ConfigureAwait(false); + + if (_fileCompletionInfoPtr != IntPtr.Zero) + { + Marshal.FreeHGlobal(_fileCompletionInfoPtr); + _fileCompletionInfoPtr = IntPtr.Zero; + } + + if (Thread.FatalError == null && ListenPipe != null) + { + await Thread.PostAsync(listener => + { + listener.ListenPipe.Dispose(); + + foreach (var pipe in listener._createdPipes) + { + pipe.Dispose(); + } + }, this).ConfigureAwait(false); + } + } + + private class PipeReadContext + { + private const int _bufferLength = 16; + + private readonly ListenerPrimary _listener; + private readonly byte[] _buf = new byte[_bufferLength]; + private readonly IntPtr _bufPtr; + private GCHandle _bufHandle; + private int _bytesRead; + + public PipeReadContext(ListenerPrimary listener) + { + _listener = listener; + _bufHandle = GCHandle.Alloc(_buf, GCHandleType.Pinned); + _bufPtr = _bufHandle.AddrOfPinnedObject(); + } + + public LibuvFunctions.uv_buf_t AllocCallback(UvStreamHandle dispatchPipe, int suggestedSize) + { + return dispatchPipe.Libuv.buf_init(_bufPtr + _bytesRead, _bufferLength - _bytesRead); + } + + public void ReadCallback(UvStreamHandle dispatchPipe, int status) + { + try + { + dispatchPipe.Libuv.ThrowIfErrored(status); + + _bytesRead += status; + + if (_bytesRead == _bufferLength) + { + var correctMessage = true; + + for (var i = 0; i < _bufferLength; i++) + { + if (_buf[i] != _listener._pipeMessage[i]) + { + correctMessage = false; + } + } + + if (correctMessage) + { + _listener._dispatchPipes.Add((UvPipeHandle)dispatchPipe); + dispatchPipe.ReadStop(); + _bufHandle.Free(); + } + else + { + throw new IOException("Bad data sent over Kestrel pipe."); + } + } + } + catch (Exception ex) + { + dispatchPipe.Dispose(); + _bufHandle.Free(); + _listener.Log.LogError(0, ex, "ListenerPrimary.ReadCallback"); + } + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/ListenerSecondary.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/ListenerSecondary.cs new file mode 100644 index 0000000000..03204caeb4 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/ListenerSecondary.cs @@ -0,0 +1,200 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + /// + /// A secondary listener is delegated requests from a primary listener via a named pipe or + /// UNIX domain socket. + /// + public class ListenerSecondary : ListenerContext, IAsyncDisposable + { + private string _pipeName; + private byte[] _pipeMessage; + private IntPtr _ptr; + private LibuvFunctions.uv_buf_t _buf; + private bool _closed; + + public ListenerSecondary(LibuvTransportContext transportContext) : base(transportContext) + { + _ptr = Marshal.AllocHGlobal(4); + } + + UvPipeHandle DispatchPipe { get; set; } + + public ILibuvTrace Log => TransportContext.Log; + + public Task StartAsync( + string pipeName, + byte[] pipeMessage, + IEndPointInformation endPointInformation, + LibuvThread thread) + { + _pipeName = pipeName; + _pipeMessage = pipeMessage; + _buf = thread.Loop.Libuv.buf_init(_ptr, 4); + + EndPointInformation = endPointInformation; + Thread = thread; + DispatchPipe = new UvPipeHandle(Log); + + var tcs = new TaskCompletionSource(this, TaskCreationOptions.RunContinuationsAsynchronously); + Thread.Post(StartCallback, tcs); + return tcs.Task; + } + + private static void StartCallback(TaskCompletionSource tcs) + { + var listener = (ListenerSecondary)tcs.Task.AsyncState; + listener.StartedCallback(tcs); + } + + private void StartedCallback(TaskCompletionSource tcs) + { + var connect = new UvConnectRequest(Log); + try + { + DispatchPipe.Init(Thread.Loop, Thread.QueueCloseHandle, true); + connect.Init(Thread); + connect.Connect( + DispatchPipe, + _pipeName, + (connect2, status, error, state) => ConnectCallback(connect2, status, error, (TaskCompletionSource)state), + tcs); + } + catch (Exception ex) + { + DispatchPipe.Dispose(); + connect.Dispose(); + tcs.SetException(ex); + } + } + + private static void ConnectCallback(UvConnectRequest connect, int status, UvException error, TaskCompletionSource tcs) + { + var listener = (ListenerSecondary)tcs.Task.AsyncState; + _ = listener.ConnectedCallback(connect, status, error, tcs); + } + + private async Task ConnectedCallback(UvConnectRequest connect, int status, UvException error, TaskCompletionSource tcs) + { + connect.Dispose(); + if (error != null) + { + tcs.SetException(error); + return; + } + + var writeReq = new UvWriteReq(Log); + + try + { + DispatchPipe.ReadStart( + (handle, status2, state) => ((ListenerSecondary)state)._buf, + (handle, status2, state) => ((ListenerSecondary)state).ReadStartCallback(handle, status2), + this); + + writeReq.Init(Thread); + var result = await writeReq.WriteAsync( + DispatchPipe, + new ArraySegment>(new[] { new ArraySegment(_pipeMessage) })); + + if (result.Error != null) + { + tcs.SetException(result.Error); + } + else + { + tcs.SetResult(0); + } + } + catch (Exception ex) + { + DispatchPipe.Dispose(); + tcs.SetException(ex); + } + finally + { + writeReq.Dispose(); + } + } + + private void ReadStartCallback(UvStreamHandle handle, int status) + { + if (status < 0) + { + if (status != LibuvConstants.EOF) + { + Thread.Loop.Libuv.Check(status, out var ex); + Log.LogError(0, ex, "DispatchPipe.ReadStart"); + } + + DispatchPipe.Dispose(); + return; + } + + if (_closed || DispatchPipe.PendingCount() == 0) + { + return; + } + + var acceptSocket = CreateAcceptSocket(); + + try + { + DispatchPipe.Accept(acceptSocket); + HandleConnectionAsync(acceptSocket); + } + catch (UvException ex) when (LibuvConstants.IsConnectionReset(ex.StatusCode)) + { + Log.ConnectionReset("(null)"); + acceptSocket.Dispose(); + } + catch (UvException ex) + { + Log.LogError(0, ex, "DispatchPipe.Accept"); + acceptSocket.Dispose(); + } + } + + private void FreeBuffer() + { + var ptr = Interlocked.Exchange(ref _ptr, IntPtr.Zero); + if (ptr != IntPtr.Zero) + { + Marshal.FreeHGlobal(ptr); + } + } + + public async Task DisposeAsync() + { + // Ensure the event loop is still running. + // If the event loop isn't running and we try to wait on this Post + // to complete, then LibuvTransport will never be disposed and + // the exception that stopped the event loop will never be surfaced. + if (Thread.FatalError == null) + { + await Thread.PostAsync(listener => + { + listener.DispatchPipe.Dispose(); + listener.FreeBuffer(); + + listener._closed = true; + + }, this).ConfigureAwait(false); + } + else + { + FreeBuffer(); + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/LibuvFunctions.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/LibuvFunctions.cs new file mode 100644 index 0000000000..d5452242e6 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/LibuvFunctions.cs @@ -0,0 +1,634 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + public class LibuvFunctions + { + public LibuvFunctions() + { + IsWindows = PlatformApis.IsWindows; + + _uv_loop_init = NativeMethods.uv_loop_init; + _uv_loop_close = NativeMethods.uv_loop_close; + _uv_run = NativeMethods.uv_run; + _uv_stop = NativeMethods.uv_stop; + _uv_ref = NativeMethods.uv_ref; + _uv_unref = NativeMethods.uv_unref; + _uv_fileno = NativeMethods.uv_fileno; + _uv_close = NativeMethods.uv_close; + _uv_async_init = NativeMethods.uv_async_init; + _uv_async_send = NativeMethods.uv_async_send; + _uv_unsafe_async_send = NativeMethods.uv_unsafe_async_send; + _uv_tcp_init = NativeMethods.uv_tcp_init; + _uv_tcp_bind = NativeMethods.uv_tcp_bind; + _uv_tcp_open = NativeMethods.uv_tcp_open; + _uv_tcp_nodelay = NativeMethods.uv_tcp_nodelay; + _uv_pipe_init = NativeMethods.uv_pipe_init; + _uv_pipe_bind = NativeMethods.uv_pipe_bind; + _uv_pipe_open = NativeMethods.uv_pipe_open; + _uv_listen = NativeMethods.uv_listen; + _uv_accept = NativeMethods.uv_accept; + _uv_pipe_connect = NativeMethods.uv_pipe_connect; + _uv_pipe_pending_count = NativeMethods.uv_pipe_pending_count; + _uv_read_start = NativeMethods.uv_read_start; + _uv_read_stop = NativeMethods.uv_read_stop; + _uv_try_write = NativeMethods.uv_try_write; + unsafe + { + _uv_write = NativeMethods.uv_write; + _uv_write2 = NativeMethods.uv_write2; + } + _uv_err_name = NativeMethods.uv_err_name; + _uv_strerror = NativeMethods.uv_strerror; + _uv_loop_size = NativeMethods.uv_loop_size; + _uv_handle_size = NativeMethods.uv_handle_size; + _uv_req_size = NativeMethods.uv_req_size; + _uv_ip4_addr = NativeMethods.uv_ip4_addr; + _uv_ip6_addr = NativeMethods.uv_ip6_addr; + _uv_tcp_getpeername = NativeMethods.uv_tcp_getpeername; + _uv_tcp_getsockname = NativeMethods.uv_tcp_getsockname; + _uv_walk = NativeMethods.uv_walk; + _uv_timer_init = NativeMethods.uv_timer_init; + _uv_timer_start = NativeMethods.uv_timer_start; + _uv_timer_stop = NativeMethods.uv_timer_stop; + _uv_now = NativeMethods.uv_now; + } + + // Second ctor that doesn't set any fields only to be used by MockLibuv + public LibuvFunctions(bool onlyForTesting) + { + } + + public readonly bool IsWindows; + + public void ThrowIfErrored(int statusCode) + { + // Note: method is explicitly small so the success case is easily inlined + if (statusCode < 0) + { + ThrowError(statusCode); + } + } + + private void ThrowError(int statusCode) + { + // Note: only has one throw block so it will marked as "Does not return" by the jit + // and not inlined into previous function, while also marking as a function + // that does not need cpu register prep to call (see: https://github.com/dotnet/coreclr/pull/6103) + throw GetError(statusCode); + } + + public void Check(int statusCode, out UvException error) + { + // Note: method is explicitly small so the success case is easily inlined + error = statusCode < 0 ? GetError(statusCode) : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private UvException GetError(int statusCode) + { + // Note: method marked as NoInlining so it doesn't bloat either of the two preceeding functions + // Check and ThrowError and alter their jit heuristics. + var errorName = err_name(statusCode); + var errorDescription = strerror(statusCode); + return new UvException("Error " + statusCode + " " + errorName + " " + errorDescription, statusCode); + } + + protected Func _uv_loop_init; + public void loop_init(UvLoopHandle handle) + { + ThrowIfErrored(_uv_loop_init(handle)); + } + + protected Func _uv_loop_close; + public void loop_close(UvLoopHandle handle) + { + handle.Validate(closed: true); + ThrowIfErrored(_uv_loop_close(handle.InternalGetHandle())); + } + + protected Func _uv_run; + public void run(UvLoopHandle handle, int mode) + { + handle.Validate(); + ThrowIfErrored(_uv_run(handle, mode)); + } + + protected Action _uv_stop; + public void stop(UvLoopHandle handle) + { + handle.Validate(); + _uv_stop(handle); + } + + protected Action _uv_ref; + public void @ref(UvHandle handle) + { + handle.Validate(); + _uv_ref(handle); + } + + protected Action _uv_unref; + public void unref(UvHandle handle) + { + handle.Validate(); + _uv_unref(handle); + } + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + protected delegate int uv_fileno_func(UvHandle handle, ref IntPtr socket); + protected uv_fileno_func _uv_fileno; + public void uv_fileno(UvHandle handle, ref IntPtr socket) + { + handle.Validate(); + ThrowIfErrored(_uv_fileno(handle, ref socket)); + } + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void uv_close_cb(IntPtr handle); + protected Action _uv_close; + public void close(UvHandle handle, uv_close_cb close_cb) + { + handle.Validate(closed: true); + _uv_close(handle.InternalGetHandle(), close_cb); + } + + public void close(IntPtr handle, uv_close_cb close_cb) + { + _uv_close(handle, close_cb); + } + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void uv_async_cb(IntPtr handle); + protected Func _uv_async_init; + public void async_init(UvLoopHandle loop, UvAsyncHandle handle, uv_async_cb cb) + { + loop.Validate(); + handle.Validate(); + ThrowIfErrored(_uv_async_init(loop, handle, cb)); + } + + protected Func _uv_async_send; + public void async_send(UvAsyncHandle handle) + { + ThrowIfErrored(_uv_async_send(handle)); + } + + protected Func _uv_unsafe_async_send; + public void unsafe_async_send(IntPtr handle) + { + ThrowIfErrored(_uv_unsafe_async_send(handle)); + } + + protected Func _uv_tcp_init; + public void tcp_init(UvLoopHandle loop, UvTcpHandle handle) + { + loop.Validate(); + handle.Validate(); + ThrowIfErrored(_uv_tcp_init(loop, handle)); + } + + protected delegate int uv_tcp_bind_func(UvTcpHandle handle, ref SockAddr addr, int flags); + protected uv_tcp_bind_func _uv_tcp_bind; + public void tcp_bind(UvTcpHandle handle, ref SockAddr addr, int flags) + { + handle.Validate(); + ThrowIfErrored(_uv_tcp_bind(handle, ref addr, flags)); + } + + protected Func _uv_tcp_open; + public void tcp_open(UvTcpHandle handle, IntPtr hSocket) + { + handle.Validate(); + ThrowIfErrored(_uv_tcp_open(handle, hSocket)); + } + + protected Func _uv_tcp_nodelay; + public void tcp_nodelay(UvTcpHandle handle, bool enable) + { + handle.Validate(); + ThrowIfErrored(_uv_tcp_nodelay(handle, enable ? 1 : 0)); + } + + protected Func _uv_pipe_init; + public void pipe_init(UvLoopHandle loop, UvPipeHandle handle, bool ipc) + { + loop.Validate(); + handle.Validate(); + ThrowIfErrored(_uv_pipe_init(loop, handle, ipc ? -1 : 0)); + } + + protected Func _uv_pipe_bind; + public void pipe_bind(UvPipeHandle handle, string name) + { + handle.Validate(); + ThrowIfErrored(_uv_pipe_bind(handle, name)); + } + + protected Func _uv_pipe_open; + public void pipe_open(UvPipeHandle handle, IntPtr hSocket) + { + handle.Validate(); + ThrowIfErrored(_uv_pipe_open(handle, hSocket)); + } + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void uv_connection_cb(IntPtr server, int status); + protected Func _uv_listen; + public void listen(UvStreamHandle handle, int backlog, uv_connection_cb cb) + { + handle.Validate(); + ThrowIfErrored(_uv_listen(handle, backlog, cb)); + } + + protected Func _uv_accept; + public void accept(UvStreamHandle server, UvStreamHandle client) + { + server.Validate(); + client.Validate(); + ThrowIfErrored(_uv_accept(server, client)); + } + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void uv_connect_cb(IntPtr req, int status); + protected Action _uv_pipe_connect; + public void pipe_connect(UvConnectRequest req, UvPipeHandle handle, string name, uv_connect_cb cb) + { + req.Validate(); + handle.Validate(); + _uv_pipe_connect(req, handle, name, cb); + } + + protected Func _uv_pipe_pending_count; + public int pipe_pending_count(UvPipeHandle handle) + { + handle.Validate(); + return _uv_pipe_pending_count(handle); + } + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void uv_alloc_cb(IntPtr server, int suggested_size, out uv_buf_t buf); + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void uv_read_cb(IntPtr server, int nread, ref uv_buf_t buf); + protected Func _uv_read_start; + public void read_start(UvStreamHandle handle, uv_alloc_cb alloc_cb, uv_read_cb read_cb) + { + handle.Validate(); + ThrowIfErrored(_uv_read_start(handle, alloc_cb, read_cb)); + } + + protected Func _uv_read_stop; + public void read_stop(UvStreamHandle handle) + { + handle.Validate(); + ThrowIfErrored(_uv_read_stop(handle)); + } + + protected Func _uv_try_write; + public int try_write(UvStreamHandle handle, uv_buf_t[] bufs, int nbufs) + { + handle.Validate(); + var count = _uv_try_write(handle, bufs, nbufs); + ThrowIfErrored(count); + return count; + } + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void uv_write_cb(IntPtr req, int status); + + unsafe protected delegate int uv_write_func(UvRequest req, UvStreamHandle handle, uv_buf_t* bufs, int nbufs, uv_write_cb cb); + protected uv_write_func _uv_write; + unsafe public void write(UvRequest req, UvStreamHandle handle, uv_buf_t* bufs, int nbufs, uv_write_cb cb) + { + req.Validate(); + handle.Validate(); + ThrowIfErrored(_uv_write(req, handle, bufs, nbufs, cb)); + } + + unsafe protected delegate int uv_write2_func(UvRequest req, UvStreamHandle handle, uv_buf_t* bufs, int nbufs, UvStreamHandle sendHandle, uv_write_cb cb); + protected uv_write2_func _uv_write2; + unsafe public void write2(UvRequest req, UvStreamHandle handle, uv_buf_t* bufs, int nbufs, UvStreamHandle sendHandle, uv_write_cb cb) + { + req.Validate(); + handle.Validate(); + ThrowIfErrored(_uv_write2(req, handle, bufs, nbufs, sendHandle, cb)); + } + + protected Func _uv_err_name; + public string err_name(int err) + { + IntPtr ptr = _uv_err_name(err); + return ptr == IntPtr.Zero ? null : Marshal.PtrToStringAnsi(ptr); + } + + protected Func _uv_strerror; + public string strerror(int err) + { + IntPtr ptr = _uv_strerror(err); + return ptr == IntPtr.Zero ? null : Marshal.PtrToStringAnsi(ptr); + } + + protected Func _uv_loop_size; + public int loop_size() + { + return _uv_loop_size(); + } + + protected Func _uv_handle_size; + public int handle_size(HandleType handleType) + { + return _uv_handle_size(handleType); + } + + protected Func _uv_req_size; + public int req_size(RequestType reqType) + { + return _uv_req_size(reqType); + } + + protected delegate int uv_ip4_addr_func(string ip, int port, out SockAddr addr); + protected uv_ip4_addr_func _uv_ip4_addr; + public void ip4_addr(string ip, int port, out SockAddr addr, out UvException error) + { + Check(_uv_ip4_addr(ip, port, out addr), out error); + } + + protected delegate int uv_ip6_addr_func(string ip, int port, out SockAddr addr); + protected uv_ip6_addr_func _uv_ip6_addr; + public void ip6_addr(string ip, int port, out SockAddr addr, out UvException error) + { + Check(_uv_ip6_addr(ip, port, out addr), out error); + } + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void uv_walk_cb(IntPtr handle, IntPtr arg); + protected Func _uv_walk; + public void walk(UvLoopHandle loop, uv_walk_cb walk_cb, IntPtr arg) + { + loop.Validate(); + _uv_walk(loop, walk_cb, arg); + } + + protected Func _uv_timer_init; + unsafe public void timer_init(UvLoopHandle loop, UvTimerHandle handle) + { + loop.Validate(); + handle.Validate(); + ThrowIfErrored(_uv_timer_init(loop, handle)); + } + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate void uv_timer_cb(IntPtr handle); + protected Func _uv_timer_start; + unsafe public void timer_start(UvTimerHandle handle, uv_timer_cb cb, long timeout, long repeat) + { + handle.Validate(); + ThrowIfErrored(_uv_timer_start(handle, cb, timeout, repeat)); + } + + protected Func _uv_timer_stop; + unsafe public void timer_stop(UvTimerHandle handle) + { + handle.Validate(); + ThrowIfErrored(_uv_timer_stop(handle)); + } + + protected Func _uv_now; + unsafe public long now(UvLoopHandle loop) + { + loop.Validate(); + return _uv_now(loop); + } + + public delegate int uv_tcp_getsockname_func(UvTcpHandle handle, out SockAddr addr, ref int namelen); + protected uv_tcp_getsockname_func _uv_tcp_getsockname; + public void tcp_getsockname(UvTcpHandle handle, out SockAddr addr, ref int namelen) + { + handle.Validate(); + ThrowIfErrored(_uv_tcp_getsockname(handle, out addr, ref namelen)); + } + + public delegate int uv_tcp_getpeername_func(UvTcpHandle handle, out SockAddr addr, ref int namelen); + protected uv_tcp_getpeername_func _uv_tcp_getpeername; + public void tcp_getpeername(UvTcpHandle handle, out SockAddr addr, ref int namelen) + { + handle.Validate(); + ThrowIfErrored(_uv_tcp_getpeername(handle, out addr, ref namelen)); + } + + public uv_buf_t buf_init(IntPtr memory, int len) + { + return new uv_buf_t(memory, len, IsWindows); + } + + public struct uv_buf_t + { + // this type represents a WSABUF struct on Windows + // https://msdn.microsoft.com/en-us/library/windows/desktop/ms741542(v=vs.85).aspx + // and an iovec struct on *nix + // http://man7.org/linux/man-pages/man2/readv.2.html + // because the order of the fields in these structs is different, the field + // names in this type don't have meaningful symbolic names. instead, they are + // assigned in the correct order by the constructor at runtime + + private readonly IntPtr _field0; + private readonly IntPtr _field1; + + public uv_buf_t(IntPtr memory, int len, bool IsWindows) + { + if (IsWindows) + { + _field0 = (IntPtr)len; + _field1 = memory; + } + else + { + _field0 = memory; + _field1 = (IntPtr)len; + } + } + } + + public enum HandleType + { + Unknown = 0, + ASYNC, + CHECK, + FS_EVENT, + FS_POLL, + HANDLE, + IDLE, + NAMED_PIPE, + POLL, + PREPARE, + PROCESS, + STREAM, + TCP, + TIMER, + TTY, + UDP, + SIGNAL, + } + + public enum RequestType + { + Unknown = 0, + REQ, + CONNECT, + WRITE, + SHUTDOWN, + UDP_SEND, + FS, + WORK, + GETADDRINFO, + GETNAMEINFO, + } + + private static class NativeMethods + { + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_loop_init(UvLoopHandle handle); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_loop_close(IntPtr a0); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_run(UvLoopHandle handle, int mode); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern void uv_stop(UvLoopHandle handle); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern void uv_ref(UvHandle handle); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern void uv_unref(UvHandle handle); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_fileno(UvHandle handle, ref IntPtr socket); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern void uv_close(IntPtr handle, uv_close_cb close_cb); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_async_init(UvLoopHandle loop, UvAsyncHandle handle, uv_async_cb cb); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public extern static int uv_async_send(UvAsyncHandle handle); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl, EntryPoint = "uv_async_send")] + public extern static int uv_unsafe_async_send(IntPtr handle); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_tcp_init(UvLoopHandle loop, UvTcpHandle handle); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_tcp_bind(UvTcpHandle handle, ref SockAddr addr, int flags); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_tcp_open(UvTcpHandle handle, IntPtr hSocket); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_tcp_nodelay(UvTcpHandle handle, int enable); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_pipe_init(UvLoopHandle loop, UvPipeHandle handle, int ipc); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_pipe_bind(UvPipeHandle loop, string name); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_pipe_open(UvPipeHandle handle, IntPtr hSocket); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_listen(UvStreamHandle handle, int backlog, uv_connection_cb cb); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_accept(UvStreamHandle server, UvStreamHandle client); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Ansi)] + public static extern void uv_pipe_connect(UvConnectRequest req, UvPipeHandle handle, string name, uv_connect_cb cb); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public extern static int uv_pipe_pending_count(UvPipeHandle handle); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public extern static int uv_read_start(UvStreamHandle handle, uv_alloc_cb alloc_cb, uv_read_cb read_cb); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_read_stop(UvStreamHandle handle); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_try_write(UvStreamHandle handle, uv_buf_t[] bufs, int nbufs); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + unsafe public static extern int uv_write(UvRequest req, UvStreamHandle handle, uv_buf_t* bufs, int nbufs, uv_write_cb cb); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + unsafe public static extern int uv_write2(UvRequest req, UvStreamHandle handle, uv_buf_t* bufs, int nbufs, UvStreamHandle sendHandle, uv_write_cb cb); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public extern static IntPtr uv_err_name(int err); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr uv_strerror(int err); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_loop_size(); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_handle_size(HandleType handleType); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_req_size(RequestType reqType); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_ip4_addr(string ip, int port, out SockAddr addr); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_ip6_addr(string ip, int port, out SockAddr addr); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_tcp_getsockname(UvTcpHandle handle, out SockAddr name, ref int namelen); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_tcp_getpeername(UvTcpHandle handle, out SockAddr name, ref int namelen); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + public static extern int uv_walk(UvLoopHandle loop, uv_walk_cb walk_cb, IntPtr arg); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + unsafe public static extern int uv_timer_init(UvLoopHandle loop, UvTimerHandle handle); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + unsafe public static extern int uv_timer_start(UvTimerHandle handle, uv_timer_cb cb, long timeout, long repeat); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + unsafe public static extern int uv_timer_stop(UvTimerHandle handle); + + [DllImport("libuv", CallingConvention = CallingConvention.Cdecl)] + unsafe public static extern long uv_now(UvLoopHandle loop); + + [DllImport("WS2_32.dll", CallingConvention = CallingConvention.Winapi)] + unsafe public static extern int WSAIoctl( + IntPtr socket, + int dwIoControlCode, + int* lpvInBuffer, + uint cbInBuffer, + int* lpvOutBuffer, + int cbOutBuffer, + out uint lpcbBytesReturned, + IntPtr lpOverlapped, + IntPtr lpCompletionRoutine + ); + + [DllImport("WS2_32.dll", CallingConvention = CallingConvention.Winapi)] + public static extern int WSAGetLastError(); + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/PlatformApis.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/PlatformApis.cs new file mode 100644 index 0000000000..ffc7619d73 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/PlatformApis.cs @@ -0,0 +1,31 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + public static class PlatformApis + { + public static bool IsWindows { get; } = RuntimeInformation.IsOSPlatform(OSPlatform.Windows); + + public static bool IsDarwin { get; } = RuntimeInformation.IsOSPlatform(OSPlatform.OSX); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long VolatileRead(ref long value) + { + if (IntPtr.Size == 8) + { + return Volatile.Read(ref value); + } + else + { + // Avoid torn long reads on 32-bit + return Interlocked.Read(ref value); + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/SockAddr.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/SockAddr.cs new file mode 100644 index 0000000000..5386749623 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/SockAddr.cs @@ -0,0 +1,125 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Net; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + [StructLayout(LayoutKind.Sequential)] + public struct SockAddr + { + // this type represents native memory occupied by sockaddr struct + // https://msdn.microsoft.com/en-us/library/windows/desktop/ms740496(v=vs.85).aspx + // although the c/c++ header defines it as a 2-byte short followed by a 14-byte array, + // the simplest way to reserve the same size in c# is with four nameless long values + private long _field0; + private long _field1; + private long _field2; + private long _field3; + + public SockAddr(long ignored) + { + _field0 = _field1 = _field2 = _field3 = 0; + } + + public unsafe IPEndPoint GetIPEndPoint() + { + // The bytes are represented in network byte order. + // + // Example 1: [2001:4898:e0:391:b9ef:1124:9d3e:a354]:39179 + // + // 0000 0000 0b99 0017 => The third and fourth bytes 990B is the actual port + // 9103 e000 9848 0120 => IPv6 address is represented in the 128bit field1 and field2. + // 54a3 3e9d 2411 efb9 Read these two 64-bit long from right to left byte by byte. + // 0000 0000 0000 0010 => Scope ID 0x10 (eg [::1%16]) the first 4 bytes of field3 in host byte order. + // + // Example 2: 10.135.34.141:39178 when adopt dual-stack sockets, IPv4 is mapped to IPv6 + // + // 0000 0000 0a99 0017 => The port representation are the same + // 0000 0000 0000 0000 + // 8d22 870a ffff 0000 => IPv4 occupies the last 32 bit: 0A.87.22.8d is the actual address. + // 0000 0000 0000 0000 + // + // Example 3: 10.135.34.141:12804, not dual-stack sockets + // + // 8d22 870a fd31 0002 => sa_family == AF_INET (02) + // 0000 0000 0000 0000 + // 0000 0000 0000 0000 + // 0000 0000 0000 0000 + // + // Example 4: 127.0.0.1:52798, on a Mac OS + // + // 0100 007F 3ECE 0210 => sa_family == AF_INET (02) Note that struct sockaddr on mac use + // 0000 0000 0000 0000 the second unint8 field for sa family type + // 0000 0000 0000 0000 http://www.opensource.apple.com/source/xnu/xnu-1456.1.26/bsd/sys/socket.h + // 0000 0000 0000 0000 + // + // Reference: + // - Windows: https://msdn.microsoft.com/en-us/library/windows/desktop/ms740506(v=vs.85).aspx + // - Linux: https://github.com/torvalds/linux/blob/6a13feb9c82803e2b815eca72fa7a9f5561d7861/include/linux/socket.h + // - Linux (sin6_scope_id): https://github.com/torvalds/linux/blob/5924bbecd0267d87c24110cbe2041b5075173a25/net/sunrpc/addr.c#L82 + // - Apple: http://www.opensource.apple.com/source/xnu/xnu-1456.1.26/bsd/sys/socket.h + + // Quick calculate the port by mask the field and locate the byte 3 and byte 4 + // and then shift them to correct place to form a int. + var port = ((int)(_field0 & 0x00FF0000) >> 8) | (int)((_field0 & 0xFF000000) >> 24); + + int family = (int)_field0; + if (PlatformApis.IsDarwin) + { + // see explaination in example 4 + family = family >> 8; + } + family = family & 0xFF; + + if (family == 2) + { + // AF_INET => IPv4 + return new IPEndPoint(new IPAddress((_field0 >> 32) & 0xFFFFFFFF), port); + } + else if (IsIPv4MappedToIPv6()) + { + var ipv4bits = (_field2 >> 32) & 0x00000000FFFFFFFF; + return new IPEndPoint(new IPAddress(ipv4bits), port); + } + else + { + // otherwise IPv6 + var bytes = new byte[16]; + fixed (byte* b = bytes) + { + *((long*)b) = _field1; + *((long*)(b + 8)) = _field2; + } + + return new IPEndPoint(new IPAddress(bytes, ScopeId), port); + } + } + + public uint ScopeId + { + get + { + return (uint)_field3; + } + set + { + _field3 &= unchecked ((long)0xFFFFFFFF00000000); + _field3 |= value; + } + } + + private bool IsIPv4MappedToIPv6() + { + // If the IPAddress is an IPv4 mapped to IPv6, return the IPv4 representation instead. + // For example [::FFFF:127.0.0.1] will be transform to IPAddress of 127.0.0.1 + if (_field1 != 0) + { + return false; + } + + return (_field2 & 0xFFFFFFFF) == 0xFFFF0000; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvAsyncHandle.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvAsyncHandle.cs new file mode 100644 index 0000000000..d68a79c4ca --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvAsyncHandle.cs @@ -0,0 +1,72 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + public class UvAsyncHandle : UvHandle + { + private static readonly LibuvFunctions.uv_close_cb _destroyMemory = (handle) => DestroyMemory(handle); + + private static readonly LibuvFunctions.uv_async_cb _uv_async_cb = (handle) => AsyncCb(handle); + private Action _callback; + private Action, IntPtr> _queueCloseHandle; + + public UvAsyncHandle(ILibuvTrace logger) : base(logger) + { + } + + public void Init(UvLoopHandle loop, Action callback, Action, IntPtr> queueCloseHandle) + { + CreateMemory( + loop.Libuv, + loop.ThreadId, + loop.Libuv.handle_size(LibuvFunctions.HandleType.ASYNC)); + + _callback = callback; + _queueCloseHandle = queueCloseHandle; + _uv.async_init(loop, this, _uv_async_cb); + } + + public void Send() + { + _uv.async_send(this); + } + + private static void AsyncCb(IntPtr handle) + { + FromIntPtr(handle)._callback.Invoke(); + } + + protected override bool ReleaseHandle() + { + var memory = handle; + if (memory != IntPtr.Zero) + { + handle = IntPtr.Zero; + + if (Thread.CurrentThread.ManagedThreadId == ThreadId) + { + _uv.close(memory, _destroyMemory); + } + else if (_queueCloseHandle != null) + { + // This can be called from the finalizer. + // Ensure the closure doesn't reference "this". + var uv = _uv; + _queueCloseHandle(memory2 => uv.close(memory2, _destroyMemory), memory); + uv.unsafe_async_send(memory); + } + else + { + Debug.Assert(false, "UvAsyncHandle not initialized with queueCloseHandle action"); + return false; + } + } + return true; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvConnectRequest.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvConnectRequest.cs new file mode 100644 index 0000000000..75c58b5434 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvConnectRequest.cs @@ -0,0 +1,78 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + /// + /// Summary description for UvWriteRequest + /// + public class UvConnectRequest : UvRequest + { + private readonly static LibuvFunctions.uv_connect_cb _uv_connect_cb = (req, status) => UvConnectCb(req, status); + + private Action _callback; + private object _state; + + public UvConnectRequest(ILibuvTrace logger) : base (logger) + { + } + + public override void Init(LibuvThread thread) + { + DangerousInit(thread.Loop); + + base.Init(thread); + } + + public void DangerousInit(UvLoopHandle loop) + { + var requestSize = loop.Libuv.req_size(LibuvFunctions.RequestType.CONNECT); + CreateMemory( + loop.Libuv, + loop.ThreadId, + requestSize); + } + + public void Connect( + UvPipeHandle pipe, + string name, + Action callback, + object state) + { + _callback = callback; + _state = state; + + Libuv.pipe_connect(this, pipe, name, _uv_connect_cb); + } + + private static void UvConnectCb(IntPtr ptr, int status) + { + var req = FromIntPtr(ptr); + + var callback = req._callback; + req._callback = null; + + var state = req._state; + req._state = null; + + UvException error = null; + if (status < 0) + { + req.Libuv.Check(status, out error); + } + + try + { + callback(req, status, error, state); + } + catch (Exception ex) + { + req._log.LogError(0, ex, "UvConnectRequest"); + throw; + } + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvException.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvException.cs new file mode 100644 index 0000000000..fa7c4087fe --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvException.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + public class UvException : Exception + { + public UvException(string message, int statusCode) : base(message) + { + StatusCode = statusCode; + } + + public int StatusCode { get; } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvHandle.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvHandle.cs new file mode 100644 index 0000000000..0f33eee7c3 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvHandle.cs @@ -0,0 +1,66 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + public abstract class UvHandle : UvMemory + { + private static readonly LibuvFunctions.uv_close_cb _destroyMemory = (handle) => DestroyMemory(handle); + private Action, IntPtr> _queueCloseHandle; + + protected UvHandle(ILibuvTrace logger) : base (logger) + { + } + + protected void CreateHandle( + LibuvFunctions uv, + int threadId, + int size, + Action, IntPtr> queueCloseHandle) + { + _queueCloseHandle = queueCloseHandle; + CreateMemory(uv, threadId, size); + } + + protected override bool ReleaseHandle() + { + var memory = handle; + if (memory != IntPtr.Zero) + { + handle = IntPtr.Zero; + + if (Thread.CurrentThread.ManagedThreadId == ThreadId) + { + _uv.close(memory, _destroyMemory); + } + else if (_queueCloseHandle != null) + { + // This can be called from the finalizer. + // Ensure the closure doesn't reference "this". + var uv = _uv; + _queueCloseHandle(memory2 => uv.close(memory2, _destroyMemory), memory); + } + else + { + Debug.Assert(false, "UvHandle not initialized with queueCloseHandle action"); + return false; + } + } + return true; + } + + public void Reference() + { + _uv.@ref(this); + } + + public void Unreference() + { + _uv.unref(this); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvLoopHandle.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvLoopHandle.cs new file mode 100644 index 0000000000..c1a47d163d --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvLoopHandle.cs @@ -0,0 +1,57 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + public class UvLoopHandle : UvMemory + { + public UvLoopHandle(ILibuvTrace logger) : base(logger) + { + } + + public void Init(LibuvFunctions uv) + { + CreateMemory( + uv, + Thread.CurrentThread.ManagedThreadId, + uv.loop_size()); + + _uv.loop_init(this); + } + + public void Run(int mode = 0) + { + _uv.run(this, mode); + } + + public void Stop() + { + _uv.stop(this); + } + + public long Now() + { + return _uv.now(this); + } + + unsafe protected override bool ReleaseHandle() + { + var memory = handle; + if (memory != IntPtr.Zero) + { + // loop_close clears the gcHandlePtr + var gcHandlePtr = *(IntPtr*)memory; + + _uv.loop_close(this); + handle = IntPtr.Zero; + + DestroyMemory(memory, gcHandlePtr); + } + + return true; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvMemory.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvMemory.cs new file mode 100644 index 0000000000..9454e6fc23 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvMemory.cs @@ -0,0 +1,93 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +#define TRACE + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + /// + /// Summary description for UvMemory + /// + public abstract class UvMemory : SafeHandle + { + protected LibuvFunctions _uv; + protected int _threadId; + protected readonly ILibuvTrace _log; + private readonly GCHandleType _handleType; + + protected UvMemory(ILibuvTrace logger, GCHandleType handleType = GCHandleType.Weak) : base(IntPtr.Zero, true) + { + _log = logger; + _handleType = handleType; + } + + public LibuvFunctions Libuv { get { return _uv; } } + + public override bool IsInvalid + { + get + { + return handle == IntPtr.Zero; + } + } + + public int ThreadId + { + get + { + return _threadId; + } + private set + { + _threadId = value; + } + } + + unsafe protected void CreateMemory(LibuvFunctions uv, int threadId, int size) + { + _uv = uv; + ThreadId = threadId; + + handle = Marshal.AllocCoTaskMem(size); + *(IntPtr*)handle = GCHandle.ToIntPtr(GCHandle.Alloc(this, _handleType)); + } + + unsafe protected static void DestroyMemory(IntPtr memory) + { + var gcHandlePtr = *(IntPtr*)memory; + DestroyMemory(memory, gcHandlePtr); + } + + protected static void DestroyMemory(IntPtr memory, IntPtr gcHandlePtr) + { + if (gcHandlePtr != IntPtr.Zero) + { + var gcHandle = GCHandle.FromIntPtr(gcHandlePtr); + gcHandle.Free(); + } + Marshal.FreeCoTaskMem(memory); + } + + public IntPtr InternalGetHandle() + { + return handle; + } + + public void Validate(bool closed = false) + { + Debug.Assert(closed || !IsClosed, "Handle is closed"); + Debug.Assert(!IsInvalid, "Handle is invalid"); + Debug.Assert(_threadId == Thread.CurrentThread.ManagedThreadId, "ThreadId is incorrect"); + } + + unsafe public static THandle FromIntPtr(IntPtr handle) + { + GCHandle gcHandle = GCHandle.FromIntPtr(*(IntPtr*)handle); + return (THandle)gcHandle.Target; + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvPipeHandle.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvPipeHandle.cs new file mode 100644 index 0000000000..9afdb67712 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvPipeHandle.cs @@ -0,0 +1,39 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + public class UvPipeHandle : UvStreamHandle + { + public UvPipeHandle(ILibuvTrace logger) : base(logger) + { + } + + public void Init(UvLoopHandle loop, Action, IntPtr> queueCloseHandle, bool ipc = false) + { + CreateHandle( + loop.Libuv, + loop.ThreadId, + loop.Libuv.handle_size(LibuvFunctions.HandleType.NAMED_PIPE), queueCloseHandle); + + _uv.pipe_init(loop, this, ipc); + } + + public void Open(IntPtr fileDescriptor) + { + _uv.pipe_open(this, fileDescriptor); + } + + public void Bind(string name) + { + _uv.pipe_bind(this, name); + } + + public int PendingCount() + { + return _uv.pipe_pending_count(this); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvRequest.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvRequest.cs new file mode 100644 index 0000000000..e11b1aaadf --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvRequest.cs @@ -0,0 +1,32 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + public class UvRequest : UvMemory + { + protected UvRequest(ILibuvTrace logger) : base(logger, GCHandleType.Normal) + { + } + + public virtual void Init(LibuvThread thread) + { +#if DEBUG + // Store weak handles to all UvRequest objects so we can do leak detection + // while running tests + thread.Requests.Add(new WeakReference(this)); +#endif + } + + protected override bool ReleaseHandle() + { + DestroyMemory(handle); + handle = IntPtr.Zero; + return true; + } + } +} + diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvStreamHandle.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvStreamHandle.cs new file mode 100644 index 0000000000..d847512fda --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvStreamHandle.cs @@ -0,0 +1,170 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Runtime.InteropServices; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + public abstract class UvStreamHandle : UvHandle + { + private readonly static LibuvFunctions.uv_connection_cb _uv_connection_cb = (handle, status) => UvConnectionCb(handle, status); + // Ref and out lamda params must be explicitly typed + private readonly static LibuvFunctions.uv_alloc_cb _uv_alloc_cb = (IntPtr handle, int suggested_size, out LibuvFunctions.uv_buf_t buf) => UvAllocCb(handle, suggested_size, out buf); + private readonly static LibuvFunctions.uv_read_cb _uv_read_cb = (IntPtr handle, int status, ref LibuvFunctions.uv_buf_t buf) => UvReadCb(handle, status, ref buf); + + private Action _listenCallback; + private object _listenState; + private GCHandle _listenVitality; + + private Func _allocCallback; + private Action _readCallback; + private object _readState; + private GCHandle _readVitality; + + protected UvStreamHandle(ILibuvTrace logger) : base(logger) + { + } + + protected override bool ReleaseHandle() + { + if (_listenVitality.IsAllocated) + { + _listenVitality.Free(); + } + if (_readVitality.IsAllocated) + { + _readVitality.Free(); + } + return base.ReleaseHandle(); + } + + public void Listen(int backlog, Action callback, object state) + { + if (_listenVitality.IsAllocated) + { + throw new InvalidOperationException("TODO: Listen may not be called more than once"); + } + try + { + _listenCallback = callback; + _listenState = state; + _listenVitality = GCHandle.Alloc(this, GCHandleType.Normal); + _uv.listen(this, backlog, _uv_connection_cb); + } + catch + { + _listenCallback = null; + _listenState = null; + if (_listenVitality.IsAllocated) + { + _listenVitality.Free(); + } + throw; + } + } + + public void Accept(UvStreamHandle handle) + { + _uv.accept(this, handle); + } + + public void ReadStart( + Func allocCallback, + Action readCallback, + object state) + { + if (_readVitality.IsAllocated) + { + throw new InvalidOperationException("TODO: ReadStop must be called before ReadStart may be called again"); + } + + try + { + _allocCallback = allocCallback; + _readCallback = readCallback; + _readState = state; + _readVitality = GCHandle.Alloc(this, GCHandleType.Normal); + _uv.read_start(this, _uv_alloc_cb, _uv_read_cb); + } + catch + { + _allocCallback = null; + _readCallback = null; + _readState = null; + if (_readVitality.IsAllocated) + { + _readVitality.Free(); + } + throw; + } + } + + // UvStreamHandle.ReadStop() should be idempotent to match uv_read_stop() + public void ReadStop() + { + if (_readVitality.IsAllocated) + { + _readVitality.Free(); + } + _allocCallback = null; + _readCallback = null; + _readState = null; + _uv.read_stop(this); + } + + public int TryWrite(LibuvFunctions.uv_buf_t buf) + { + return _uv.try_write(this, new[] { buf }, 1); + } + + private static void UvConnectionCb(IntPtr handle, int status) + { + var stream = FromIntPtr(handle); + + stream.Libuv.Check(status, out var error); + + try + { + stream._listenCallback(stream, status, error, stream._listenState); + } + catch (Exception ex) + { + stream._log.LogError(0, ex, "UvConnectionCb"); + throw; + } + } + + private static void UvAllocCb(IntPtr handle, int suggested_size, out LibuvFunctions.uv_buf_t buf) + { + var stream = FromIntPtr(handle); + try + { + buf = stream._allocCallback(stream, suggested_size, stream._readState); + } + catch (Exception ex) + { + stream._log.LogError(0, ex, "UvAllocCb"); + buf = stream.Libuv.buf_init(IntPtr.Zero, 0); + throw; + } + } + + private static void UvReadCb(IntPtr handle, int status, ref LibuvFunctions.uv_buf_t buf) + { + var stream = FromIntPtr(handle); + + try + { + stream._readCallback(stream, status, stream._readState); + } + catch (Exception ex) + { + stream._log.LogError(0, ex, "UbReadCb"); + throw; + } + } + + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvTcpHandle.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvTcpHandle.cs new file mode 100644 index 0000000000..7bb6e0e908 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvTcpHandle.cs @@ -0,0 +1,80 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + public class UvTcpHandle : UvStreamHandle + { + public UvTcpHandle(ILibuvTrace logger) : base(logger) + { + } + + public void Init(UvLoopHandle loop, Action, IntPtr> queueCloseHandle) + { + CreateHandle( + loop.Libuv, + loop.ThreadId, + loop.Libuv.handle_size(LibuvFunctions.HandleType.TCP), queueCloseHandle); + + _uv.tcp_init(loop, this); + } + + public void Open(IntPtr fileDescriptor) + { + _uv.tcp_open(this, fileDescriptor); + } + + public void Bind(IPEndPoint endPoint) + { + SockAddr addr; + var addressText = endPoint.Address.ToString(); + + _uv.ip4_addr(addressText, endPoint.Port, out addr, out var error1); + + if (error1 != null) + { + _uv.ip6_addr(addressText, endPoint.Port, out addr, out var error2); + if (error2 != null) + { + throw error1; + } + + if (endPoint.Address.ScopeId != addr.ScopeId) + { + // IPAddress.ScopeId cannot be less than 0 or greater than 0xFFFFFFFF + // https://msdn.microsoft.com/en-us/library/system.net.ipaddress.scopeid(v=vs.110).aspx + addr.ScopeId = (uint)endPoint.Address.ScopeId; + } + } + + _uv.tcp_bind(this, ref addr, 0); + } + + public IPEndPoint GetPeerIPEndPoint() + { + SockAddr socketAddress; + int namelen = Marshal.SizeOf(); + _uv.tcp_getpeername(this, out socketAddress, ref namelen); + + return socketAddress.GetIPEndPoint(); + } + + public IPEndPoint GetSockIPEndPoint() + { + SockAddr socketAddress; + int namelen = Marshal.SizeOf(); + _uv.tcp_getsockname(this, out socketAddress, ref namelen); + + return socketAddress.GetIPEndPoint(); + } + + public void NoDelay(bool enable) + { + _uv.tcp_nodelay(this, enable); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvTimerHandle.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvTimerHandle.cs new file mode 100644 index 0000000000..85547e7d69 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvTimerHandle.cs @@ -0,0 +1,56 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + public class UvTimerHandle : UvHandle + { + private readonly static LibuvFunctions.uv_timer_cb _uv_timer_cb = UvTimerCb; + + private Action _callback; + + public UvTimerHandle(ILibuvTrace logger) : base(logger) + { + } + + public void Init(UvLoopHandle loop, Action, IntPtr> queueCloseHandle) + { + CreateHandle( + loop.Libuv, + loop.ThreadId, + loop.Libuv.handle_size(LibuvFunctions.HandleType.TIMER), + queueCloseHandle); + + _uv.timer_init(loop, this); + } + + public void Start(Action callback, long timeout, long repeat) + { + _callback = callback; + _uv.timer_start(this, _uv_timer_cb, timeout, repeat); + } + + public void Stop() + { + _uv.timer_stop(this); + } + + private static void UvTimerCb(IntPtr handle) + { + var timer = FromIntPtr(handle); + + try + { + timer._callback(timer); + } + catch (Exception ex) + { + timer._log.LogError(0, ex, nameof(UvTimerCb)); + throw; + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvWriteReq.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvWriteReq.cs new file mode 100644 index 0000000000..e28c318616 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/Networking/UvWriteReq.cs @@ -0,0 +1,256 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking +{ + /// + /// Summary description for UvWriteRequest + /// + public class UvWriteReq : UvRequest + { + private static readonly LibuvFunctions.uv_write_cb _uv_write_cb = (IntPtr ptr, int status) => UvWriteCb(ptr, status); + + private IntPtr _bufs; + + private Action _callback; + private object _state; + private const int BUFFER_COUNT = 4; + + private LibuvAwaitable _awaitable = new LibuvAwaitable(); + private List _pins = new List(BUFFER_COUNT + 1); + private List _handles = new List(BUFFER_COUNT + 1); + + public UvWriteReq(ILibuvTrace logger) : base(logger) + { + } + + public override void Init(LibuvThread thread) + { + DangerousInit(thread.Loop); + + base.Init(thread); + } + + public void DangerousInit(UvLoopHandle loop) + { + var requestSize = loop.Libuv.req_size(LibuvFunctions.RequestType.WRITE); + var bufferSize = Marshal.SizeOf() * BUFFER_COUNT; + CreateMemory( + loop.Libuv, + loop.ThreadId, + requestSize + bufferSize); + _bufs = handle + requestSize; + } + + public LibuvAwaitable WriteAsync(UvStreamHandle handle, ReadOnlySequence buffer) + { + Write(handle, buffer, LibuvAwaitable.Callback, _awaitable); + return _awaitable; + } + + public LibuvAwaitable WriteAsync(UvStreamHandle handle, ArraySegment> bufs) + { + Write(handle, bufs, LibuvAwaitable.Callback, _awaitable); + return _awaitable; + } + + private unsafe void Write( + UvStreamHandle handle, + ReadOnlySequence buffer, + Action callback, + object state) + { + try + { + var nBuffers = 0; + if (buffer.IsSingleSegment) + { + nBuffers = 1; + } + else + { + foreach (var _ in buffer) + { + nBuffers++; + } + } + + var pBuffers = (LibuvFunctions.uv_buf_t*)_bufs; + if (nBuffers > BUFFER_COUNT) + { + // create and pin buffer array when it's larger than the pre-allocated one + var bufArray = new LibuvFunctions.uv_buf_t[nBuffers]; + var gcHandle = GCHandle.Alloc(bufArray, GCHandleType.Pinned); + _pins.Add(gcHandle); + pBuffers = (LibuvFunctions.uv_buf_t*)gcHandle.AddrOfPinnedObject(); + } + + if (nBuffers == 1) + { + var memory = buffer.First; + var memoryHandle = memory.Pin(); + _handles.Add(memoryHandle); + + // Fast path for single buffer + pBuffers[0] = Libuv.buf_init( + (IntPtr)memoryHandle.Pointer, + memory.Length); + } + else + { + var index = 0; + foreach (var memory in buffer) + { + // This won't actually pin the buffer since we're already using pinned memory + var memoryHandle = memory.Pin(); + _handles.Add(memoryHandle); + + // create and pin each segment being written + pBuffers[index] = Libuv.buf_init( + (IntPtr)memoryHandle.Pointer, + memory.Length); + index++; + } + } + + _callback = callback; + _state = state; + _uv.write(this, handle, pBuffers, nBuffers, _uv_write_cb); + } + catch + { + _callback = null; + _state = null; + UnpinGcHandles(); + throw; + } + } + + private void Write( + UvStreamHandle handle, + ArraySegment> bufs, + Action callback, + object state) + { + WriteArraySegmentInternal(handle, bufs, sendHandle: null, callback: callback, state: state); + } + + public void Write2( + UvStreamHandle handle, + ArraySegment> bufs, + UvStreamHandle sendHandle, + Action callback, + object state) + { + WriteArraySegmentInternal(handle, bufs, sendHandle, callback, state); + } + + private unsafe void WriteArraySegmentInternal( + UvStreamHandle handle, + ArraySegment> bufs, + UvStreamHandle sendHandle, + Action callback, + object state) + { + try + { + var pBuffers = (LibuvFunctions.uv_buf_t*)_bufs; + var nBuffers = bufs.Count; + if (nBuffers > BUFFER_COUNT) + { + // create and pin buffer array when it's larger than the pre-allocated one + var bufArray = new LibuvFunctions.uv_buf_t[nBuffers]; + var gcHandle = GCHandle.Alloc(bufArray, GCHandleType.Pinned); + _pins.Add(gcHandle); + pBuffers = (LibuvFunctions.uv_buf_t*)gcHandle.AddrOfPinnedObject(); + } + + for (var index = 0; index < nBuffers; index++) + { + // create and pin each segment being written + var buf = bufs.Array[bufs.Offset + index]; + + var gcHandle = GCHandle.Alloc(buf.Array, GCHandleType.Pinned); + _pins.Add(gcHandle); + pBuffers[index] = Libuv.buf_init( + gcHandle.AddrOfPinnedObject() + buf.Offset, + buf.Count); + } + + _callback = callback; + _state = state; + + if (sendHandle == null) + { + _uv.write(this, handle, pBuffers, nBuffers, _uv_write_cb); + } + else + { + _uv.write2(this, handle, pBuffers, nBuffers, sendHandle, _uv_write_cb); + } + } + catch + { + _callback = null; + _state = null; + UnpinGcHandles(); + throw; + } + } + + // Safe handle has instance method called Unpin + // so using UnpinGcHandles to avoid conflict + private void UnpinGcHandles() + { + var pinList = _pins; + var count = pinList.Count; + for (var i = 0; i < count; i++) + { + pinList[i].Free(); + } + pinList.Clear(); + + var handleList = _handles; + count = handleList.Count; + for (var i = 0; i < count; i++) + { + handleList[i].Dispose(); + } + handleList.Clear(); + } + + private static void UvWriteCb(IntPtr ptr, int status) + { + var req = FromIntPtr(ptr); + req.UnpinGcHandles(); + + var callback = req._callback; + req._callback = null; + + var state = req._state; + req._state = null; + + UvException error = null; + if (status < 0) + { + req.Libuv.Check(status, out error); + } + + try + { + callback(req, status, error, state); + } + catch (Exception ex) + { + req._log.LogError(0, ex, "UvWriteCb"); + throw; + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Internal/WriteReqPool.cs b/src/Servers/Kestrel/Transport.Libuv/src/Internal/WriteReqPool.cs new file mode 100644 index 0000000000..cdc612ee4f --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Internal/WriteReqPool.cs @@ -0,0 +1,76 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal +{ + public class WriteReqPool + { + private const int _maxPooledWriteReqs = 1024; + + private readonly LibuvThread _thread; + private readonly Queue _pool = new Queue(_maxPooledWriteReqs); + private readonly ILibuvTrace _log; + private bool _disposed; + + public WriteReqPool(LibuvThread thread, ILibuvTrace log) + { + _thread = thread; + _log = log; + } + + public UvWriteReq Allocate() + { + if (_disposed) + { + throw new ObjectDisposedException(GetType().Name); + } + + UvWriteReq req; + if (_pool.Count > 0) + { + req = _pool.Dequeue(); + } + else + { + req = new UvWriteReq(_log); + req.Init(_thread); + } + + return req; + } + + public void Return(UvWriteReq req) + { + if (_disposed) + { + throw new ObjectDisposedException(GetType().Name); + } + + if (_pool.Count < _maxPooledWriteReqs) + { + _pool.Enqueue(req); + } + else + { + req.Dispose(); + } + } + + public void Dispose() + { + if (!_disposed) + { + _disposed = true; + + while (_pool.Count > 0) + { + _pool.Dequeue().Dispose(); + } + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/LibuvTransportOptions.cs b/src/Servers/Kestrel/Transport.Libuv/src/LibuvTransportOptions.cs new file mode 100644 index 0000000000..db157d39dd --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/LibuvTransportOptions.cs @@ -0,0 +1,47 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv +{ + /// + /// Provides programmatic configuration of Libuv transport features. + /// + public class LibuvTransportOptions + { + /// + /// The number of libuv I/O threads used to process requests. + /// + /// + /// Defaults to half of rounded down and clamped between 1 and 16. + /// + public int ThreadCount { get; set; } = ProcessorThreadCount; + + private static int ProcessorThreadCount + { + get + { + // Actual core count would be a better number + // rather than logical cores which includes hyper-threaded cores. + // Divide by 2 for hyper-threading, and good defaults (still need threads to do webserving). + var threadCount = Environment.ProcessorCount >> 1; + + if (threadCount < 1) + { + // Ensure shifted value is at least one + return 1; + } + + if (threadCount > 16) + { + // Receive Side Scaling RSS Processor count currently maxes out at 16 + // would be better to check the NIC's current hardware queues; but xplat... + return 16; + } + + return threadCount; + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.csproj b/src/Servers/Kestrel/Transport.Libuv/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.csproj new file mode 100644 index 0000000000..001e97cdb6 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.csproj @@ -0,0 +1,20 @@ + + + + Libuv transport for the ASP.NET Core Kestrel cross-platform web server. + netstandard2.0 + true + aspnetcore;kestrel + true + CS1591;$(NoWarn) + + + + + + + + + + + diff --git a/src/Servers/Kestrel/Transport.Libuv/src/WebHostBuilderLibuvExtensions.cs b/src/Servers/Kestrel/Transport.Libuv/src/WebHostBuilderLibuvExtensions.cs new file mode 100644 index 0000000000..386d0b6679 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/WebHostBuilderLibuvExtensions.cs @@ -0,0 +1,51 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Hosting +{ + public static class WebHostBuilderLibuvExtensions + { + /// + /// Specify Libuv as the transport to be used by Kestrel. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder to configure. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder. + /// + public static IWebHostBuilder UseLibuv(this IWebHostBuilder hostBuilder) + { + return hostBuilder.ConfigureServices(services => + { + services.AddSingleton(); + }); + } + + /// + /// Specify Libuv as the transport to be used by Kestrel. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder to configure. + /// + /// + /// A callback to configure Libuv options. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder. + /// + public static IWebHostBuilder UseLibuv(this IWebHostBuilder hostBuilder, Action configureOptions) + { + return hostBuilder.UseLibuv().ConfigureServices(services => + { + services.Configure(configureOptions); + }); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/src/baseline.netcore.json b/src/Servers/Kestrel/Transport.Libuv/src/baseline.netcore.json new file mode 100644 index 0000000000..b7c1b7e61b --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/baseline.netcore.json @@ -0,0 +1,267 @@ +{ + "AssemblyIdentity": "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv, Version=2.0.2.0, Culture=neutral, PublicKeyToken=adb9793829ddae60", + "Types": [ + { + "Name": "Microsoft.AspNetCore.Hosting.WebHostBuilderLibuvExtensions", + "Visibility": "Public", + "Kind": "Class", + "Abstract": true, + "Static": true, + "Sealed": true, + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "UseLibuv", + "Parameters": [ + { + "Name": "hostBuilder", + "Type": "Microsoft.AspNetCore.Hosting.IWebHostBuilder" + } + ], + "ReturnType": "Microsoft.AspNetCore.Hosting.IWebHostBuilder", + "Static": true, + "Extension": true, + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "UseLibuv", + "Parameters": [ + { + "Name": "hostBuilder", + "Type": "Microsoft.AspNetCore.Hosting.IWebHostBuilder" + }, + { + "Name": "configureOptions", + "Type": "System.Action" + } + ], + "ReturnType": "Microsoft.AspNetCore.Hosting.IWebHostBuilder", + "Static": true, + "Extension": true, + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.LibuvTransport", + "Visibility": "Public", + "Kind": "Class", + "ImplementedInterfaces": [ + "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ITransport" + ], + "Members": [ + { + "Kind": "Method", + "Name": "get_Libuv", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking.LibuvFunctions", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_TransportContext", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.LibuvTransportContext", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_Threads", + "Parameters": [], + "ReturnType": "System.Collections.Generic.List", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_AppLifetime", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Hosting.IApplicationLifetime", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_Log", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.ILibuvTrace", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "get_TransportOptions", + "Parameters": [], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.LibuvTransportOptions", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "StopAsync", + "Parameters": [], + "ReturnType": "System.Threading.Tasks.Task", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ITransport", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "BindAsync", + "Parameters": [], + "ReturnType": "System.Threading.Tasks.Task", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ITransport", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "UnbindAsync", + "Parameters": [], + "ReturnType": "System.Threading.Tasks.Task", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ITransport", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Constructor", + "Name": ".ctor", + "Parameters": [ + { + "Name": "context", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.LibuvTransportContext" + }, + { + "Name": "endPointInformation", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IEndPointInformation" + } + ], + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Constructor", + "Name": ".ctor", + "Parameters": [ + { + "Name": "uv", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking.LibuvFunctions" + }, + { + "Name": "context", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.LibuvTransportContext" + }, + { + "Name": "endPointInformation", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IEndPointInformation" + } + ], + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.LibuvTransportFactory", + "Visibility": "Public", + "Kind": "Class", + "ImplementedInterfaces": [ + "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ITransportFactory" + ], + "Members": [ + { + "Kind": "Method", + "Name": "Create", + "Parameters": [ + { + "Name": "endPointInformation", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IEndPointInformation" + }, + { + "Name": "handler", + "Type": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.IConnectionHandler" + } + ], + "ReturnType": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ITransport", + "Sealed": true, + "Virtual": true, + "ImplementedInterface": "Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ITransportFactory", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Constructor", + "Name": ".ctor", + "Parameters": [ + { + "Name": "options", + "Type": "Microsoft.Extensions.Options.IOptions" + }, + { + "Name": "applicationLifetime", + "Type": "Microsoft.AspNetCore.Hosting.IApplicationLifetime" + }, + { + "Name": "loggerFactory", + "Type": "Microsoft.Extensions.Logging.ILoggerFactory" + } + ], + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + }, + { + "Name": "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.LibuvTransportOptions", + "Visibility": "Public", + "Kind": "Class", + "ImplementedInterfaces": [], + "Members": [ + { + "Kind": "Method", + "Name": "get_ThreadCount", + "Parameters": [], + "ReturnType": "System.Int32", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Method", + "Name": "set_ThreadCount", + "Parameters": [ + { + "Name": "value", + "Type": "System.Int32" + } + ], + "ReturnType": "System.Void", + "Visibility": "Public", + "GenericParameter": [] + }, + { + "Kind": "Constructor", + "Name": ".ctor", + "Parameters": [], + "Visibility": "Public", + "GenericParameter": [] + } + ], + "GenericParameters": [] + } + ] +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Transport.Libuv/src/breakingchanges.netcore.json b/src/Servers/Kestrel/Transport.Libuv/src/breakingchanges.netcore.json new file mode 100644 index 0000000000..bb70d76ea8 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/src/breakingchanges.netcore.json @@ -0,0 +1,10 @@ +[ + { + "TypeId": "public class Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.LibuvTransport : Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ITransport", + "Kind": "Removal" + }, + { + "TypeId": "public class Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.LibuvTransportFactory : Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal.ITransportFactory", + "Kind": "Removal" + } +] diff --git a/src/Servers/Kestrel/Transport.Libuv/test/LibuvConnectionTests.cs b/src/Servers/Kestrel/Transport.Libuv/test/LibuvConnectionTests.cs new file mode 100644 index 0000000000..0184cdb86f --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/LibuvConnectionTests.cs @@ -0,0 +1,235 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.TestHelpers; +using Microsoft.AspNetCore.Testing; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests +{ + public class LibuvConnectionTests + { + [Fact] + public async Task DoesNotEndConnectionOnZeroRead() + { + var mockConnectionDispatcher = new MockConnectionDispatcher(); + var mockLibuv = new MockLibuv(); + var transportContext = new TestLibuvTransportContext() { ConnectionDispatcher = mockConnectionDispatcher }; + var transport = new LibuvTransport(mockLibuv, transportContext, null); + var thread = new LibuvThread(transport); + + try + { + await thread.StartAsync(); + await thread.PostAsync(_ => + { + var listenerContext = new ListenerContext(transportContext) + { + Thread = thread + }; + var socket = new MockSocket(mockLibuv, Thread.CurrentThread.ManagedThreadId, transportContext.Log); + var connection = new LibuvConnection(socket, listenerContext.TransportContext.Log, thread, null, null); + listenerContext.TransportContext.ConnectionDispatcher.OnConnection(connection); + _ = connection.Start(); + + mockLibuv.AllocCallback(socket.InternalGetHandle(), 2048, out var ignored); + mockLibuv.ReadCallback(socket.InternalGetHandle(), 0, ref ignored); + }, (object)null); + + var readAwaitable = mockConnectionDispatcher.Input.Reader.ReadAsync(); + Assert.False(readAwaitable.IsCompleted); + } + finally + { + await thread.StopAsync(TimeSpan.FromSeconds(5)); + } + } + + [Fact] + public async Task ConnectionDoesNotResumeAfterSocketCloseIfBackpressureIsApplied() + { + var mockConnectionDispatcher = new MockConnectionDispatcher(); + var mockLibuv = new MockLibuv(); + var transportContext = new TestLibuvTransportContext() { ConnectionDispatcher = mockConnectionDispatcher }; + var transport = new LibuvTransport(mockLibuv, transportContext, null); + var thread = new LibuvThread(transport); + mockConnectionDispatcher.InputOptions = pool => + new PipeOptions( + pool: pool, + pauseWriterThreshold: 3, + readerScheduler: PipeScheduler.Inline, + writerScheduler: PipeScheduler.Inline, + useSynchronizationContext: false); + + // We don't set the output writer scheduler here since we want to run the callback inline + + mockConnectionDispatcher.OutputOptions = pool => new PipeOptions(pool: pool, readerScheduler: thread, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + + + Task connectionTask = null; + try + { + await thread.StartAsync(); + + // Write enough to make sure back pressure will be applied + await thread.PostAsync(_ => + { + var listenerContext = new ListenerContext(transportContext) + { + Thread = thread + }; + var socket = new MockSocket(mockLibuv, Thread.CurrentThread.ManagedThreadId, transportContext.Log); + var connection = new LibuvConnection(socket, listenerContext.TransportContext.Log, thread, null, null); + listenerContext.TransportContext.ConnectionDispatcher.OnConnection(connection); + connectionTask = connection.Start(); + + mockLibuv.AllocCallback(socket.InternalGetHandle(), 2048, out var ignored); + mockLibuv.ReadCallback(socket.InternalGetHandle(), 5, ref ignored); + + }, null); + + // Now assert that we removed the callback from libuv to stop reading + Assert.Null(mockLibuv.AllocCallback); + Assert.Null(mockLibuv.ReadCallback); + + // Now complete the output writer so that the connection closes + mockConnectionDispatcher.Output.Writer.Complete(); + + await connectionTask.DefaultTimeout(); + + // Assert that we don't try to start reading + Assert.Null(mockLibuv.AllocCallback); + Assert.Null(mockLibuv.ReadCallback); + } + finally + { + await thread.StopAsync(TimeSpan.FromSeconds(5)); + } + } + + [Fact] + public async Task ConnectionDoesNotResumeAfterReadCallbackScheduledAndSocketCloseIfBackpressureIsApplied() + { + var mockConnectionDispatcher = new MockConnectionDispatcher(); + var mockLibuv = new MockLibuv(); + var transportContext = new TestLibuvTransportContext() { ConnectionDispatcher = mockConnectionDispatcher }; + var transport = new LibuvTransport(mockLibuv, transportContext, null); + var thread = new LibuvThread(transport); + var mockScheduler = new Mock(); + Action backPressure = null; + mockScheduler.Setup(m => m.Schedule(It.IsAny>(), It.IsAny())).Callback, object>((a, o) => + { + backPressure = () => a(o); + }); + mockConnectionDispatcher.InputOptions = pool => + new PipeOptions( + pool: pool, + pauseWriterThreshold: 3, + resumeWriterThreshold: 3, + writerScheduler: mockScheduler.Object, + readerScheduler: PipeScheduler.Inline, + useSynchronizationContext: false); + + mockConnectionDispatcher.OutputOptions = pool => new PipeOptions(pool: pool, readerScheduler: thread, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + + Task connectionTask = null; + try + { + await thread.StartAsync(); + + // Write enough to make sure back pressure will be applied + await thread.PostAsync(_ => + { + var listenerContext = new ListenerContext(transportContext) + { + Thread = thread + }; + var socket = new MockSocket(mockLibuv, Thread.CurrentThread.ManagedThreadId, transportContext.Log); + var connection = new LibuvConnection(socket, listenerContext.TransportContext.Log, thread, null, null); + listenerContext.TransportContext.ConnectionDispatcher.OnConnection(connection); + connectionTask = connection.Start(); + + mockLibuv.AllocCallback(socket.InternalGetHandle(), 2048, out var ignored); + mockLibuv.ReadCallback(socket.InternalGetHandle(), 5, ref ignored); + + }, null); + + // Now assert that we removed the callback from libuv to stop reading + Assert.Null(mockLibuv.AllocCallback); + Assert.Null(mockLibuv.ReadCallback); + + // Now release backpressure by reading the input + var result = await mockConnectionDispatcher.Input.Reader.ReadAsync(); + // Calling advance will call into our custom scheduler that captures the back pressure + // callback + mockConnectionDispatcher.Input.Reader.AdvanceTo(result.Buffer.End); + + // Cancel the current pending flush + mockConnectionDispatcher.Input.Writer.CancelPendingFlush(); + + // Now release the back pressure + await thread.PostAsync(a => a(), backPressure); + + // Assert that we don't try to start reading since the write was cancelled + Assert.Null(mockLibuv.AllocCallback); + Assert.Null(mockLibuv.ReadCallback); + + // Now complete the output writer and wait for the connection to close + mockConnectionDispatcher.Output.Writer.Complete(); + + await connectionTask.DefaultTimeout(); + + // Assert that we don't try to start reading + Assert.Null(mockLibuv.AllocCallback); + Assert.Null(mockLibuv.ReadCallback); + } + finally + { + await thread.StopAsync(TimeSpan.FromSeconds(5)); + } + } + + [Fact] + public async Task DoesNotThrowIfOnReadCallbackCalledWithEOFButAllocCallbackNotCalled() + { + var mockConnectionDispatcher = new MockConnectionDispatcher(); + var mockLibuv = new MockLibuv(); + var transportContext = new TestLibuvTransportContext() { ConnectionDispatcher = mockConnectionDispatcher }; + var transport = new LibuvTransport(mockLibuv, transportContext, null); + var thread = new LibuvThread(transport); + + try + { + await thread.StartAsync(); + await thread.PostAsync(_ => + { + var listenerContext = new ListenerContext(transportContext) + { + Thread = thread + }; + var socket = new MockSocket(mockLibuv, Thread.CurrentThread.ManagedThreadId, transportContext.Log); + var connection = new LibuvConnection(socket, listenerContext.TransportContext.Log, thread, null, null); + listenerContext.TransportContext.ConnectionDispatcher.OnConnection(connection); + _ = connection.Start(); + + var ignored = new LibuvFunctions.uv_buf_t(); + mockLibuv.ReadCallback(socket.InternalGetHandle(), TestConstants.EOF, ref ignored); + }, (object)null); + + var readAwaitable = await mockConnectionDispatcher.Input.Reader.ReadAsync(); + Assert.True(readAwaitable.IsCompleted); + } + finally + { + await thread.StopAsync(TimeSpan.FromSeconds(5)); + } + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Transport.Libuv/test/LibuvOutputConsumerTests.cs b/src/Servers/Kestrel/Transport.Libuv/test/LibuvOutputConsumerTests.cs new file mode 100644 index 0000000000..d2f6a4fb89 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/LibuvOutputConsumerTests.cs @@ -0,0 +1,779 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Concurrent; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.TestHelpers; +using Microsoft.AspNetCore.Testing; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests +{ + public class LibuvOutputConsumerTests : IDisposable + { + private readonly MemoryPool _memoryPool; + private readonly MockLibuv _mockLibuv; + private readonly LibuvThread _libuvThread; + + public static TheoryData MaxResponseBufferSizeData => new TheoryData + { + new KestrelServerOptions().Limits.MaxResponseBufferSize, 0, 1024, 1024 * 1024, null + }; + + public static TheoryData PositiveMaxResponseBufferSizeData => new TheoryData + { + (int)new KestrelServerOptions().Limits.MaxResponseBufferSize, 1024, (1024 * 1024) + 1 + }; + + public LibuvOutputConsumerTests() + { + _memoryPool = KestrelMemoryPool.Create(); + _mockLibuv = new MockLibuv(); + + var libuvTransport = new LibuvTransport(_mockLibuv, new TestLibuvTransportContext(), new ListenOptions((ulong)0)); + _libuvThread = new LibuvThread(libuvTransport, maxLoops: 1); + _libuvThread.StartAsync().Wait(); + } + + public void Dispose() + { + _libuvThread.StopAsync(TimeSpan.FromSeconds(5)).Wait(); + _memoryPool.Dispose(); + } + + [Theory] + [MemberData(nameof(MaxResponseBufferSizeData))] + public async Task CanWrite1MB(long? maxResponseBufferSize) + { + // This test was added because when initially implementing write-behind buffering in + // SocketOutput, the write callback would never be invoked for writes larger than + // maxResponseBufferSize even after the write actually completed. + + // ConnectionHandler will set Pause/ResumeWriterThreshold to zero when MaxResponseBufferSize is null. + // This is verified in PipeOptionsTests.OutputPipeOptionsConfiguredCorrectly. + var pipeOptions = new PipeOptions + ( + pool: _memoryPool, + readerScheduler: _libuvThread, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: maxResponseBufferSize ?? 0, + resumeWriterThreshold: maxResponseBufferSize ?? 0, + useSynchronizationContext: false + ); + + using (var outputProducer = CreateOutputProducer(pipeOptions)) + { + // At least one run of this test should have a MaxResponseBufferSize < 1 MB. + var bufferSize = 1024 * 1024; + var buffer = new ArraySegment(new byte[bufferSize], 0, bufferSize); + + // Act + var writeTask = outputProducer.WriteDataAsync(buffer); + + // Assert + await writeTask.DefaultTimeout(); + } + } + + [Fact] + public async Task NullMaxResponseBufferSizeAllowsUnlimitedBuffer() + { + var completeQueue = new ConcurrentQueue>(); + + // Arrange + _mockLibuv.OnWrite = (socket, buffers, triggerCompleted) => + { + completeQueue.Enqueue(triggerCompleted); + return 0; + }; + + // ConnectionHandler will set Pause/ResumeWriterThreshold to zero when MaxResponseBufferSize is null. + // This is verified in PipeOptionsTests.OutputPipeOptionsConfiguredCorrectly. + var pipeOptions = new PipeOptions + ( + pool: _memoryPool, + readerScheduler: _libuvThread, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: 0, + resumeWriterThreshold: 0, + useSynchronizationContext: false + ); + + using (var outputProducer = CreateOutputProducer(pipeOptions)) + { + // Don't want to allocate anything too huge for perf. This is at least larger than the default buffer. + var bufferSize = 1024 * 1024; + var buffer = new ArraySegment(new byte[bufferSize], 0, bufferSize); + + // Act + var writeTask = outputProducer.WriteDataAsync(buffer); + + // Assert + await writeTask.DefaultTimeout(); + + // Cleanup + outputProducer.Dispose(); + + // Wait for all writes to complete so the completeQueue isn't modified during enumeration. + await _mockLibuv.OnPostTask; + + // Drain the write queue + while (completeQueue.TryDequeue(out var triggerNextCompleted)) + { + await _libuvThread.PostAsync(cb => cb(0), triggerNextCompleted); + } + } + } + + [Fact] + public async Task ZeroMaxResponseBufferSizeDisablesBuffering() + { + var completeQueue = new ConcurrentQueue>(); + + // Arrange + _mockLibuv.OnWrite = (socket, buffers, triggerCompleted) => + { + completeQueue.Enqueue(triggerCompleted); + return 0; + }; + + // ConnectionHandler will set Pause/ResumeWriterThreshold to 1 when MaxResponseBufferSize is zero. + // This is verified in PipeOptionsTests.OutputPipeOptionsConfiguredCorrectly. + var pipeOptions = new PipeOptions + ( + pool: _memoryPool, + readerScheduler: _libuvThread, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: 1, + resumeWriterThreshold: 1, + useSynchronizationContext: false + ); + + using (var outputProducer = CreateOutputProducer(pipeOptions)) + { + var bufferSize = 1; + var buffer = new ArraySegment(new byte[bufferSize], 0, bufferSize); + + // Act + var writeTask = outputProducer.WriteDataAsync(buffer); + + // Assert + Assert.False(writeTask.IsCompleted); + + // Act + await _mockLibuv.OnPostTask; + + // Finishing the write should allow the task to complete. + Assert.True(completeQueue.TryDequeue(out var triggerNextCompleted)); + await _libuvThread.PostAsync(cb => cb(0), triggerNextCompleted); + + // Assert + await writeTask.DefaultTimeout(); + + // Cleanup + outputProducer.Dispose(); + + // Wait for all writes to complete so the completeQueue isn't modified during enumeration. + await _mockLibuv.OnPostTask; + + // Drain the write queue + while (completeQueue.TryDequeue(out triggerNextCompleted)) + { + await _libuvThread.PostAsync(cb => cb(0), triggerNextCompleted); + } + } + } + + [Theory] + [MemberData(nameof(PositiveMaxResponseBufferSizeData))] + public async Task WritesDontCompleteImmediatelyWhenTooManyBytesAreAlreadyBuffered(int maxResponseBufferSize) + { + var completeQueue = new ConcurrentQueue>(); + + // Arrange + _mockLibuv.OnWrite = (socket, buffers, triggerCompleted) => + { + completeQueue.Enqueue(triggerCompleted); + return 0; + }; + + var pipeOptions = new PipeOptions + ( + pool: _memoryPool, + readerScheduler: _libuvThread, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: maxResponseBufferSize, + resumeWriterThreshold: maxResponseBufferSize, + useSynchronizationContext: false + ); + + using (var outputProducer = CreateOutputProducer(pipeOptions)) + { + var bufferSize = maxResponseBufferSize - 1; + var buffer = new ArraySegment(new byte[bufferSize], 0, bufferSize); + + // Act + var writeTask1 = outputProducer.WriteDataAsync(buffer); + + // Assert + // The first write should pre-complete since it is <= _maxBytesPreCompleted. + Assert.Equal(TaskStatus.RanToCompletion, writeTask1.Status); + + // Act + var writeTask2 = outputProducer.WriteDataAsync(buffer); + await _mockLibuv.OnPostTask; + + // Assert + // Too many bytes are already pre-completed for the second write to pre-complete. + Assert.False(writeTask2.IsCompleted); + + // Act + Assert.True(completeQueue.TryDequeue(out var triggerNextCompleted)); + await _libuvThread.PostAsync(cb => cb(0), triggerNextCompleted); + + // Finishing the first write should allow the second write to pre-complete. + await writeTask2.DefaultTimeout(); + + // Cleanup + outputProducer.Dispose(); + + // Wait for all writes to complete so the completeQueue isn't modified during enumeration. + await _mockLibuv.OnPostTask; + + // Drain the write queue + while (completeQueue.TryDequeue(out triggerNextCompleted)) + { + await _libuvThread.PostAsync(cb => cb(0), triggerNextCompleted); + } + } + } + + [Theory] + [MemberData(nameof(PositiveMaxResponseBufferSizeData))] + public async Task WritesDontCompleteImmediatelyWhenTooManyBytesIncludingNonImmediateAreAlreadyBuffered(int maxResponseBufferSize) + { + await Task.Run(async () => + { + var completeQueue = new ConcurrentQueue>(); + + // Arrange + _mockLibuv.OnWrite = (socket, buffers, triggerCompleted) => + { + completeQueue.Enqueue(triggerCompleted); + return 0; + }; + + var pipeOptions = new PipeOptions + ( + pool: _memoryPool, + readerScheduler: _libuvThread, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: maxResponseBufferSize, + resumeWriterThreshold: maxResponseBufferSize, + useSynchronizationContext: false + ); + + using (var outputProducer = CreateOutputProducer(pipeOptions)) + { + var bufferSize = maxResponseBufferSize / 2; + var data = new byte[bufferSize]; + var halfWriteBehindBuffer = new ArraySegment(data, 0, bufferSize); + + // Act + var writeTask1 = outputProducer.WriteDataAsync(halfWriteBehindBuffer); + + // Assert + // The first write should pre-complete since it is <= _maxBytesPreCompleted. + Assert.Equal(TaskStatus.RanToCompletion, writeTask1.Status); + await _mockLibuv.OnPostTask; + Assert.NotEmpty(completeQueue); + + // Add more bytes to the write-behind buffer to prevent the next write from + outputProducer.Write((writableBuffer, state) => + { + writableBuffer.Write(state); + return state.Count; + }, + halfWriteBehindBuffer); + + // Act + var writeTask2 = outputProducer.WriteDataAsync(halfWriteBehindBuffer); + Assert.False(writeTask2.IsCompleted); + + var writeTask3 = outputProducer.WriteDataAsync(halfWriteBehindBuffer); + Assert.False(writeTask3.IsCompleted); + + // Drain the write queue + while (completeQueue.TryDequeue(out var triggerNextCompleted)) + { + await _libuvThread.PostAsync(cb => cb(0), triggerNextCompleted); + } + + var timeout = TestConstants.DefaultTimeout; + + await writeTask2.TimeoutAfter(timeout); + await writeTask3.TimeoutAfter(timeout); + + Assert.Empty(completeQueue); + } + }); + } + + [Theory] + [MemberData(nameof(PositiveMaxResponseBufferSizeData))] + public async Task FailedWriteCompletesOrCancelsAllPendingTasks(int maxResponseBufferSize) + { + await Task.Run(async () => + { + var completeQueue = new ConcurrentQueue>(); + + // Arrange + _mockLibuv.OnWrite = (socket, buffers, triggerCompleted) => + { + completeQueue.Enqueue(triggerCompleted); + return 0; + }; + + var abortedSource = new CancellationTokenSource(); + + var pipeOptions = new PipeOptions + ( + pool: _memoryPool, + readerScheduler: _libuvThread, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: maxResponseBufferSize, + resumeWriterThreshold: maxResponseBufferSize, + useSynchronizationContext: false + ); + + using (var outputProducer = CreateOutputProducer(pipeOptions, abortedSource)) + { + var bufferSize = maxResponseBufferSize - 1; + + var data = new byte[bufferSize]; + var fullBuffer = new ArraySegment(data, 0, bufferSize); + + // Act + var task1Success = outputProducer.WriteDataAsync(fullBuffer, cancellationToken: abortedSource.Token); + // task1 should complete successfully as < _maxBytesPreCompleted + + // First task is completed and successful + Assert.True(task1Success.IsCompleted); + Assert.False(task1Success.IsCanceled); + Assert.False(task1Success.IsFaulted); + + // following tasks should wait. + var task2Success = outputProducer.WriteDataAsync(fullBuffer); + var task3Canceled = outputProducer.WriteDataAsync(fullBuffer, cancellationToken: abortedSource.Token); + + // Give time for tasks to percolate + await _mockLibuv.OnPostTask; + + // Second task is not completed + Assert.False(task2Success.IsCompleted); + Assert.False(task2Success.IsCanceled); + Assert.False(task2Success.IsFaulted); + + // Third task is not completed + Assert.False(task3Canceled.IsCompleted); + Assert.False(task3Canceled.IsCanceled); + Assert.False(task3Canceled.IsFaulted); + + // Cause all writes to fail + while (completeQueue.TryDequeue(out var triggerNextCompleted)) + { + await _libuvThread.PostAsync(cb => cb(-1), triggerNextCompleted); + } + + // Second task is now completed + Assert.True(task2Success.IsCompleted); + Assert.False(task2Success.IsCanceled); + Assert.False(task2Success.IsFaulted); + + // A final write guarantees that the error is observed by OutputProducer, + // but doesn't return a canceled/faulted task. + var task4Success = outputProducer.WriteDataAsync(fullBuffer, cancellationToken: default(CancellationToken)); + Assert.True(task4Success.IsCompleted); + Assert.False(task4Success.IsCanceled); + Assert.False(task4Success.IsFaulted); + + // Third task is now canceled + await Assert.ThrowsAsync(() => task3Canceled); + Assert.True(task3Canceled.IsCanceled); + + Assert.True(abortedSource.IsCancellationRequested); + + await _mockLibuv.OnPostTask; + + // Complete the 4th write + while (completeQueue.TryDequeue(out var triggerNextCompleted)) + { + await _libuvThread.PostAsync(cb => cb(0), triggerNextCompleted); + } + } + }); + } + + [Theory] + [MemberData(nameof(PositiveMaxResponseBufferSizeData))] + public async Task CancelsBeforeWriteRequestCompletes(int maxResponseBufferSize) + { + await Task.Run(async () => + { + var completeQueue = new ConcurrentQueue>(); + + // Arrange + _mockLibuv.OnWrite = (socket, buffers, triggerCompleted) => + { + completeQueue.Enqueue(triggerCompleted); + return 0; + }; + + var abortedSource = new CancellationTokenSource(); + + var pipeOptions = new PipeOptions + ( + pool: _memoryPool, + readerScheduler: _libuvThread, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: maxResponseBufferSize, + resumeWriterThreshold: maxResponseBufferSize, + useSynchronizationContext: false + ); + + using (var outputProducer = CreateOutputProducer(pipeOptions)) + { + var bufferSize = maxResponseBufferSize - 1; + + var data = new byte[bufferSize]; + var fullBuffer = new ArraySegment(data, 0, bufferSize); + + // Act + var task1Success = outputProducer.WriteDataAsync(fullBuffer, cancellationToken: abortedSource.Token); + // task1 should complete successfully as < _maxBytesPreCompleted + + // First task is completed and successful + Assert.True(task1Success.IsCompleted); + Assert.False(task1Success.IsCanceled); + Assert.False(task1Success.IsFaulted); + + // following tasks should wait. + var task3Canceled = outputProducer.WriteDataAsync(fullBuffer, cancellationToken: abortedSource.Token); + + // Give time for tasks to percolate + await _mockLibuv.OnPostTask; + + // Third task is not completed + Assert.False(task3Canceled.IsCompleted); + Assert.False(task3Canceled.IsCanceled); + Assert.False(task3Canceled.IsFaulted); + + abortedSource.Cancel(); + + // Complete writes + while (completeQueue.TryDequeue(out var triggerNextCompleted)) + { + await _libuvThread.PostAsync(cb => cb(0), triggerNextCompleted); + } + + // A final write guarantees that the error is observed by OutputProducer, + // but doesn't return a canceled/faulted task. + var task4Success = outputProducer.WriteDataAsync(fullBuffer); + Assert.True(task4Success.IsCompleted); + Assert.False(task4Success.IsCanceled); + Assert.False(task4Success.IsFaulted); + + // Third task is now canceled + await Assert.ThrowsAsync(() => task3Canceled); + Assert.True(task3Canceled.IsCanceled); + + Assert.True(abortedSource.IsCancellationRequested); + + await _mockLibuv.OnPostTask; + + // Complete the 4th write + while (completeQueue.TryDequeue(out var triggerNextCompleted)) + { + await _libuvThread.PostAsync(cb => cb(0), triggerNextCompleted); + } + } + }); + } + + [Theory] + [MemberData(nameof(PositiveMaxResponseBufferSizeData))] + public async Task WriteAsyncWithTokenAfterCallWithoutIsCancelled(int maxResponseBufferSize) + { + await Task.Run(async () => + { + var completeQueue = new ConcurrentQueue>(); + + // Arrange + _mockLibuv.OnWrite = (socket, buffers, triggerCompleted) => + { + completeQueue.Enqueue(triggerCompleted); + return 0; + }; + + var abortedSource = new CancellationTokenSource(); + + var pipeOptions = new PipeOptions + ( + pool: _memoryPool, + readerScheduler: _libuvThread, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: maxResponseBufferSize, + resumeWriterThreshold: maxResponseBufferSize, + useSynchronizationContext: false + ); + + using (var outputProducer = CreateOutputProducer(pipeOptions)) + { + var bufferSize = maxResponseBufferSize; + + var data = new byte[bufferSize]; + var fullBuffer = new ArraySegment(data, 0, bufferSize); + + // Act + var task1Waits = outputProducer.WriteDataAsync(fullBuffer); + + // First task is not completed + Assert.False(task1Waits.IsCompleted); + Assert.False(task1Waits.IsCanceled); + Assert.False(task1Waits.IsFaulted); + + // following tasks should wait. + var task3Canceled = outputProducer.WriteDataAsync(fullBuffer, cancellationToken: abortedSource.Token); + + // Give time for tasks to percolate + await _mockLibuv.OnPostTask; + + // Third task is not completed + Assert.False(task3Canceled.IsCompleted); + Assert.False(task3Canceled.IsCanceled); + Assert.False(task3Canceled.IsFaulted); + + abortedSource.Cancel(); + + // Complete writes + while (completeQueue.TryDequeue(out var triggerNextCompleted)) + { + await _libuvThread.PostAsync(cb => cb(0), triggerNextCompleted); + } + + // First task is completed + Assert.True(task1Waits.IsCompleted); + Assert.False(task1Waits.IsCanceled); + Assert.False(task1Waits.IsFaulted); + + // A final write guarantees that the error is observed by OutputProducer, + // but doesn't return a canceled/faulted task. + var task4Success = outputProducer.WriteDataAsync(fullBuffer); + Assert.True(task4Success.IsCompleted); + Assert.False(task4Success.IsCanceled); + Assert.False(task4Success.IsFaulted); + + // Third task is now canceled + await Assert.ThrowsAsync(() => task3Canceled); + Assert.True(task3Canceled.IsCanceled); + + await _mockLibuv.OnPostTask; + + // Complete the 4th write + while (completeQueue.TryDequeue(out var triggerNextCompleted)) + { + await _libuvThread.PostAsync(cb => cb(0), triggerNextCompleted); + } + } + }); + } + + [Theory] + [MemberData(nameof(PositiveMaxResponseBufferSizeData))] + public async Task WritesDontGetCompletedTooQuickly(int maxResponseBufferSize) + { + var completeQueue = new ConcurrentQueue>(); + + // Arrange + _mockLibuv.OnWrite = (socket, buffers, triggerCompleted) => + { + completeQueue.Enqueue(triggerCompleted); + return 0; + }; + + var pipeOptions = new PipeOptions + ( + pool: _memoryPool, + readerScheduler: _libuvThread, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: maxResponseBufferSize, + resumeWriterThreshold: maxResponseBufferSize, + useSynchronizationContext: false + ); + + using (var outputProducer = CreateOutputProducer(pipeOptions)) + { + var bufferSize = maxResponseBufferSize - 1; + var buffer = new ArraySegment(new byte[bufferSize], 0, bufferSize); + + // Act (Pre-complete the maximum number of bytes in preparation for the rest of the test) + var writeTask1 = outputProducer.WriteDataAsync(buffer); + + // Assert + // The first write should pre-complete since it is < _maxBytesPreCompleted. + await _mockLibuv.OnPostTask; + Assert.Equal(TaskStatus.RanToCompletion, writeTask1.Status); + Assert.NotEmpty(completeQueue); + + // Act + var writeTask2 = outputProducer.WriteDataAsync(buffer); + var writeTask3 = outputProducer.WriteDataAsync(buffer); + + await _mockLibuv.OnPostTask; + + // Drain the write queue + while (completeQueue.TryDequeue(out var triggerNextCompleted)) + { + await _libuvThread.PostAsync(cb => cb(0), triggerNextCompleted); + } + + var timeout = TestConstants.DefaultTimeout; + + // Assert + // Too many bytes are already pre-completed for the third but not the second write to pre-complete. + // https://github.com/aspnet/KestrelHttpServer/issues/356 + await writeTask2.TimeoutAfter(timeout); + await writeTask3.TimeoutAfter(timeout); + } + } + + [Theory] + [MemberData(nameof(MaxResponseBufferSizeData))] + public async Task WritesAreAggregated(long? maxResponseBufferSize) + { + var writeCalled = false; + var writeCount = 0; + + _mockLibuv.OnWrite = (socket, buffers, triggerCompleted) => + { + writeCount++; + triggerCompleted(0); + writeCalled = true; + return 0; + }; + + // ConnectionHandler will set Pause/ResumeWriterThreshold to zero when MaxResponseBufferSize is null. + // This is verified in PipeOptionsTests.OutputPipeOptionsConfiguredCorrectly. + var pipeOptions = new PipeOptions + ( + pool: _memoryPool, + readerScheduler: _libuvThread, + writerScheduler: PipeScheduler.Inline, + pauseWriterThreshold: maxResponseBufferSize ?? 0, + resumeWriterThreshold: maxResponseBufferSize ?? 0, + useSynchronizationContext: false + ); + + using (var outputProducer = CreateOutputProducer(pipeOptions)) + { + _mockLibuv.KestrelThreadBlocker.Reset(); + + var buffer = new ArraySegment(new byte[1]); + + // Two calls to WriteAsync trigger uv_write once if both calls + // are made before write is scheduled + var ignore = outputProducer.WriteDataAsync(buffer); + ignore = outputProducer.WriteDataAsync(buffer); + + _mockLibuv.KestrelThreadBlocker.Set(); + + await _mockLibuv.OnPostTask; + + Assert.True(writeCalled); + writeCalled = false; + + // Write isn't called twice after the thread is unblocked + await _mockLibuv.OnPostTask; + + Assert.False(writeCalled); + // One call to ScheduleWrite + Assert.Equal(1, _mockLibuv.PostCount); + // One call to uv_write + Assert.Equal(1, writeCount); + } + } + + private Http1OutputProducer CreateOutputProducer(PipeOptions pipeOptions, CancellationTokenSource cts = null) + { + var pair = DuplexPipe.CreateConnectionPair(pipeOptions, pipeOptions); + + var logger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext + { + Log = new TestKestrelTrace(logger), + Scheduler = PipeScheduler.Inline + }; + var transportContext = new TestLibuvTransportContext { Log = new LibuvTrace(logger) }; + + var socket = new MockSocket(_mockLibuv, _libuvThread.Loop.ThreadId, transportContext.Log); + var consumer = new LibuvOutputConsumer(pair.Application.Input, _libuvThread, socket, "0", transportContext.Log); + + var connectionFeatures = new FeatureCollection(); + connectionFeatures.Set(Mock.Of()); + connectionFeatures.Set(Mock.Of()); + + var http1Connection = new Http1Connection(new Http1ConnectionContext + { + ServiceContext = serviceContext, + ConnectionContext = Mock.Of(), + ConnectionFeatures = connectionFeatures, + MemoryPool = _memoryPool, + TimeoutControl = Mock.Of(), + Application = pair.Application, + Transport = pair.Transport + }); + + if (cts != null) + { + http1Connection.RequestAborted.Register(cts.Cancel); + } + + var ignore = WriteOutputAsync(consumer, pair.Application.Input, http1Connection); + + return (Http1OutputProducer)http1Connection.Output; + } + + private async Task WriteOutputAsync(LibuvOutputConsumer consumer, PipeReader outputReader, Http1Connection http1Connection) + { + // This WriteOutputAsync() calling code is equivalent to that in LibuvConnection. + try + { + // Ensure that outputReader.Complete() runs on the LibuvThread. + // Without ConfigureAwait(false), xunit will dispatch. + await consumer.WriteOutputAsync().ConfigureAwait(false); + + http1Connection.Abort(abortReason: null); + outputReader.Complete(); + } + catch (UvException ex) + { + http1Connection.Abort(new ConnectionAbortedException(ex.Message, ex)); + outputReader.Complete(ex); + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/test/LibuvThreadTests.cs b/src/Servers/Kestrel/Transport.Libuv/test/LibuvThreadTests.cs new file mode 100644 index 0000000000..5177420d9c --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/LibuvThreadTests.cs @@ -0,0 +1,77 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.TestHelpers; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests +{ + public class LibuvThreadTests + { + [Fact] + public async Task LibuvThreadDoesNotThrowIfPostingWorkAfterDispose() + { + var mockConnectionDispatcher = new MockConnectionDispatcher(); + var mockLibuv = new MockLibuv(); + var transportContext = new TestLibuvTransportContext() { ConnectionDispatcher = mockConnectionDispatcher }; + var transport = new LibuvTransport(mockLibuv, transportContext, null); + var thread = new LibuvThread(transport); + var ranOne = false; + var ranTwo = false; + var ranThree = false; + var ranFour = false; + + await thread.StartAsync(); + + await thread.PostAsync(_ => + { + ranOne = true; + }, + null); + + Assert.Equal(1, mockLibuv.PostCount); + + // Shutdown the libuv thread + await thread.StopAsync(TimeSpan.FromSeconds(5)); + + Assert.Equal(2, mockLibuv.PostCount); + + var task = thread.PostAsync(_ => + { + ranTwo = true; + }, + null); + + Assert.Equal(2, mockLibuv.PostCount); + + thread.Post(_ => + { + ranThree = true; + }, + null); + + Assert.Equal(2, mockLibuv.PostCount); + + thread.Schedule(_ => + { + ranFour = true; + }, + (object)null); + + Assert.Equal(2, mockLibuv.PostCount); + + Assert.True(task.IsCompleted); + Assert.True(ranOne); + Assert.False(ranTwo); + Assert.False(ranThree); + Assert.False(ranFour); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportFactoryTests.cs b/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportFactoryTests.cs new file mode 100644 index 0000000000..2db383f941 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportFactoryTests.cs @@ -0,0 +1,37 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests +{ + public class LibuvTransportFactoryTests + { + [Theory] + [InlineData(0)] + [InlineData(-1337)] + public void StartWithNonPositiveThreadCountThrows(int threadCount) + { + var options = new LibuvTransportOptions() { ThreadCount = threadCount }; + + var exception = Assert.Throws(() => + new LibuvTransportFactory(Options.Create(options), new LifetimeNotImplemented(), Mock.Of())); + + Assert.Equal("threadCount", exception.ParamName); + } + + [Fact] + public void LoggerCategoryNameIsLibuvTransportNamespace() + { + var mockLoggerFactory = new Mock(); + new LibuvTransportFactory(Options.Create(new LibuvTransportOptions()), new LifetimeNotImplemented(), mockLoggerFactory.Object); + mockLoggerFactory.Verify(factory => factory.CreateLogger("Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv")); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportOptionsTests.cs b/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportOptionsTests.cs new file mode 100644 index 0000000000..8651bdb2ff --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportOptionsTests.cs @@ -0,0 +1,27 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests +{ + public class LibuvTransportOptionsTests + { + [Fact] + public void SetThreadCountUsingProcessorCount() + { + // Ideally we'd mock Environment.ProcessorCount to test edge cases. + var expected = Clamp(Environment.ProcessorCount >> 1, 1, 16); + + var information = new LibuvTransportOptions(); + + Assert.Equal(expected, information.ThreadCount); + } + + private static int Clamp(int value, int min, int max) + { + return value < min ? min : value > max ? max : value; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportTests.cs b/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportTests.cs new file mode 100644 index 0000000000..4d2be2380c --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportTests.cs @@ -0,0 +1,142 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Net.Sockets; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.TestHelpers; +using Microsoft.AspNetCore.Testing; +using Microsoft.AspNetCore.Testing.xunit; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests +{ + public class LibuvTransportTests + { + public static TheoryData ConnectionAdapterData => new TheoryData + { + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)), + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = { new PassThroughConnectionAdapter() } + } + }; + + public static IEnumerable OneToTen => Enumerable.Range(1, 10).Select(i => new object[] { i }); + + [Fact] + public async Task TransportCanBindAndStop() + { + var transportContext = new TestLibuvTransportContext(); + var transport = new LibuvTransport(transportContext, + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0))); + + // The transport can no longer start threads without binding to an endpoint. + await transport.BindAsync(); + await transport.StopAsync(); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task TransportCanBindUnbindAndStop(ListenOptions listenOptions) + { + var transportContext = new TestLibuvTransportContext(); + var transport = new LibuvTransport(transportContext, listenOptions); + + await transport.BindAsync(); + await transport.UnbindAsync(); + await transport.StopAsync(); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ConnectionCanReadAndWrite(ListenOptions listenOptions) + { + var serviceContext = new TestServiceContext(); + listenOptions.UseHttpServer(listenOptions.ConnectionAdapters, serviceContext, new DummyApplication(TestApp.EchoApp), HttpProtocols.Http1); + + var transportContext = new TestLibuvTransportContext() + { + ConnectionDispatcher = new ConnectionDispatcher(serviceContext, listenOptions.Build()) + }; + + var transport = new LibuvTransport(transportContext, listenOptions); + + await transport.BindAsync(); + + using (var socket = TestConnection.CreateConnectedLoopbackSocket(listenOptions.IPEndPoint.Port)) + { + var data = "Hello World"; + socket.Send(Encoding.ASCII.GetBytes($"POST / HTTP/1.0\r\nContent-Length: 11\r\n\r\n{data}")); + var buffer = new byte[data.Length]; + var read = 0; + while (read < data.Length) + { + read += socket.Receive(buffer, read, buffer.Length - read, SocketFlags.None); + } + } + + await transport.UnbindAsync(); + await transport.StopAsync(); + } + + [ConditionalTheory] + [MemberData(nameof(OneToTen))] + [OSSkipCondition(OperatingSystems.MacOSX, SkipReason = "Tests fail on OS X due to low file descriptor limit.")] + public async Task OneToTenThreads(int threadCount) + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + var serviceContext = new TestServiceContext(); + var testApplication = new DummyApplication(context => + { + return context.Response.WriteAsync("Hello World"); + }); + + listenOptions.UseHttpServer(listenOptions.ConnectionAdapters, serviceContext, testApplication, HttpProtocols.Http1); + + var transportContext = new TestLibuvTransportContext() + { + ConnectionDispatcher = new ConnectionDispatcher(serviceContext, listenOptions.Build()), + Options = new LibuvTransportOptions { ThreadCount = threadCount } + }; + + var transport = new LibuvTransport(transportContext, listenOptions); + + await transport.BindAsync(); + + using (var client = new HttpClient()) + { + // Send 20 requests just to make sure we don't get any failures + var requestTasks = new List>(); + for (int i = 0; i < 20; i++) + { + var requestTask = client.GetStringAsync($"http://127.0.0.1:{listenOptions.IPEndPoint.Port}/"); + requestTasks.Add(requestTask); + } + + foreach (var result in await Task.WhenAll(requestTasks)) + { + Assert.Equal("Hello World", result); + } + } + + await transport.UnbindAsync(); + + if (!await serviceContext.ConnectionManager.CloseAllConnectionsAsync(default).ConfigureAwait(false)) + { + await serviceContext.ConnectionManager.AbortAllConnectionsAsync().ConfigureAwait(false); + } + + await transport.StopAsync(); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/test/ListenerPrimaryTests.cs b/src/Servers/Kestrel/Transport.Libuv/test/ListenerPrimaryTests.cs new file mode 100644 index 0000000000..2d906fc2cf --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/ListenerPrimaryTests.cs @@ -0,0 +1,341 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.TestHelpers; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests +{ + public class ListenerPrimaryTests + { + [Fact] + public async Task ConnectionsGetRoundRobinedToSecondaryListeners() + { + var libuv = new LibuvFunctions(); + + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + + var serviceContextPrimary = new TestServiceContext(); + var transportContextPrimary = new TestLibuvTransportContext(); + var builderPrimary = new ConnectionBuilder(); + builderPrimary.UseHttpServer(serviceContextPrimary, new DummyApplication(c => c.Response.WriteAsync("Primary")), HttpProtocols.Http1); + transportContextPrimary.ConnectionDispatcher = new ConnectionDispatcher(serviceContextPrimary, builderPrimary.Build()); + + var serviceContextSecondary = new TestServiceContext(); + var builderSecondary = new ConnectionBuilder(); + builderSecondary.UseHttpServer(serviceContextSecondary, new DummyApplication(c => c.Response.WriteAsync("Secondary")), HttpProtocols.Http1); + var transportContextSecondary = new TestLibuvTransportContext(); + transportContextSecondary.ConnectionDispatcher = new ConnectionDispatcher(serviceContextSecondary, builderSecondary.Build()); + + var libuvTransport = new LibuvTransport(libuv, transportContextPrimary, listenOptions); + + var pipeName = (libuv.IsWindows ? @"\\.\pipe\kestrel_" : "/tmp/kestrel_") + Guid.NewGuid().ToString("n"); + var pipeMessage = Guid.NewGuid().ToByteArray(); + + // Start primary listener + var libuvThreadPrimary = new LibuvThread(libuvTransport); + await libuvThreadPrimary.StartAsync(); + var listenerPrimary = new ListenerPrimary(transportContextPrimary); + await listenerPrimary.StartAsync(pipeName, pipeMessage, listenOptions, libuvThreadPrimary); + var address = GetUri(listenOptions); + + // Until a secondary listener is added, TCP connections get dispatched directly + Assert.Equal("Primary", await HttpClientSlim.GetStringAsync(address)); + Assert.Equal("Primary", await HttpClientSlim.GetStringAsync(address)); + + var listenerCount = listenerPrimary.UvPipeCount; + // Add secondary listener + var libuvThreadSecondary = new LibuvThread(libuvTransport); + await libuvThreadSecondary.StartAsync(); + var listenerSecondary = new ListenerSecondary(transportContextSecondary); + await listenerSecondary.StartAsync(pipeName, pipeMessage, listenOptions, libuvThreadSecondary); + + var maxWait = Task.Delay(TestConstants.DefaultTimeout); + // wait for ListenerPrimary.ReadCallback to add the secondary pipe + while (listenerPrimary.UvPipeCount == listenerCount) + { + var completed = await Task.WhenAny(maxWait, Task.Delay(100)); + if (ReferenceEquals(completed, maxWait)) + { + throw new TimeoutException("Timed out waiting for secondary listener to become available"); + } + } + + // Once a secondary listener is added, TCP connections start getting dispatched to it + await AssertResponseEventually(address, "Secondary", allowed: new[] { "Primary" }); + + // TCP connections will still get round-robined to the primary listener + Assert.Equal("Primary", await HttpClientSlim.GetStringAsync(address)); + Assert.Equal("Secondary", await HttpClientSlim.GetStringAsync(address)); + Assert.Equal("Primary", await HttpClientSlim.GetStringAsync(address)); + + await listenerSecondary.DisposeAsync(); + await libuvThreadSecondary.StopAsync(TimeSpan.FromSeconds(5)); + + await listenerPrimary.DisposeAsync(); + await libuvThreadPrimary.StopAsync(TimeSpan.FromSeconds(5)); + } + + // https://github.com/aspnet/KestrelHttpServer/issues/1182 + [Fact] + public async Task NonListenerPipeConnectionsAreLoggedAndIgnored() + { + var libuv = new LibuvFunctions(); + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + var logger = new TestApplicationErrorLogger(); + + var serviceContextPrimary = new TestServiceContext(); + var builderPrimary = new ConnectionBuilder(); + builderPrimary.UseHttpServer(serviceContextPrimary, new DummyApplication(c => c.Response.WriteAsync("Primary")), HttpProtocols.Http1); + var transportContextPrimary = new TestLibuvTransportContext() { Log = new LibuvTrace(logger) }; + transportContextPrimary.ConnectionDispatcher = new ConnectionDispatcher(serviceContextPrimary, builderPrimary.Build()); + + var serviceContextSecondary = new TestServiceContext + { + DateHeaderValueManager = serviceContextPrimary.DateHeaderValueManager, + ServerOptions = serviceContextPrimary.ServerOptions, + Scheduler = serviceContextPrimary.Scheduler, + HttpParser = serviceContextPrimary.HttpParser, + }; + var builderSecondary = new ConnectionBuilder(); + builderSecondary.UseHttpServer(serviceContextSecondary, new DummyApplication(c => c.Response.WriteAsync("Secondary")), HttpProtocols.Http1); + var transportContextSecondary = new TestLibuvTransportContext(); + transportContextSecondary.ConnectionDispatcher = new ConnectionDispatcher(serviceContextSecondary, builderSecondary.Build()); + + var libuvTransport = new LibuvTransport(libuv, transportContextPrimary, listenOptions); + + var pipeName = (libuv.IsWindows ? @"\\.\pipe\kestrel_" : "/tmp/kestrel_") + Guid.NewGuid().ToString("n"); + var pipeMessage = Guid.NewGuid().ToByteArray(); + + // Start primary listener + var libuvThreadPrimary = new LibuvThread(libuvTransport); + await libuvThreadPrimary.StartAsync(); + var listenerPrimary = new ListenerPrimary(transportContextPrimary); + await listenerPrimary.StartAsync(pipeName, pipeMessage, listenOptions, libuvThreadPrimary); + var address = GetUri(listenOptions); + + // Add secondary listener + var libuvThreadSecondary = new LibuvThread(libuvTransport); + await libuvThreadSecondary.StartAsync(); + var listenerSecondary = new ListenerSecondary(transportContextSecondary); + await listenerSecondary.StartAsync(pipeName, pipeMessage, listenOptions, libuvThreadSecondary); + + // TCP Connections get round-robined + await AssertResponseEventually(address, "Secondary", allowed: new[] { "Primary" }); + Assert.Equal("Primary", await HttpClientSlim.GetStringAsync(address)); + + // Create a pipe connection and keep it open without sending any data + var connectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var connectionTrace = new LibuvTrace(new TestApplicationErrorLogger()); + var pipe = new UvPipeHandle(connectionTrace); + + libuvThreadPrimary.Post(_ => + { + var connectReq = new UvConnectRequest(connectionTrace); + + pipe.Init(libuvThreadPrimary.Loop, libuvThreadPrimary.QueueCloseHandle); + connectReq.Init(libuvThreadPrimary); + + connectReq.Connect( + pipe, + pipeName, + (req, status, ex, __) => + { + req.Dispose(); + + if (ex != null) + { + connectTcs.SetException(ex); + } + else + { + connectTcs.SetResult(null); + } + }, + null); + }, (object)null); + + await connectTcs.Task; + + // TCP connections will still get round-robined between only the two listeners + Assert.Equal("Secondary", await HttpClientSlim.GetStringAsync(address)); + Assert.Equal("Primary", await HttpClientSlim.GetStringAsync(address)); + Assert.Equal("Secondary", await HttpClientSlim.GetStringAsync(address)); + + await libuvThreadPrimary.PostAsync(_ => pipe.Dispose(), (object)null); + + // Wait up to 10 seconds for error to be logged + for (var i = 0; i < 10 && logger.TotalErrorsLogged == 0; i++) + { + await Task.Delay(100); + } + + // Same for after the non-listener pipe connection is closed + Assert.Equal("Primary", await HttpClientSlim.GetStringAsync(address)); + Assert.Equal("Secondary", await HttpClientSlim.GetStringAsync(address)); + Assert.Equal("Primary", await HttpClientSlim.GetStringAsync(address)); + + await listenerSecondary.DisposeAsync(); + await libuvThreadSecondary.StopAsync(TimeSpan.FromSeconds(5)); + + await listenerPrimary.DisposeAsync(); + await libuvThreadPrimary.StopAsync(TimeSpan.FromSeconds(5)); + + Assert.Equal(1, logger.TotalErrorsLogged); + var errorMessage = logger.Messages.First(m => m.LogLevel == LogLevel.Error); + Assert.Equal(TestConstants.EOF, Assert.IsType(errorMessage.Exception).StatusCode); + } + + + [Fact] + public async Task PipeConnectionsWithWrongMessageAreLoggedAndIgnored() + { + var libuv = new LibuvFunctions(); + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + + var logger = new TestApplicationErrorLogger(); + + var serviceContextPrimary = new TestServiceContext(); + var builderPrimary = new ConnectionBuilder(); + builderPrimary.UseHttpServer(serviceContextPrimary, new DummyApplication(c => c.Response.WriteAsync("Primary")), HttpProtocols.Http1); + var transportContextPrimary = new TestLibuvTransportContext() { Log = new LibuvTrace(logger) }; + transportContextPrimary.ConnectionDispatcher = new ConnectionDispatcher(serviceContextPrimary, builderPrimary.Build()); + + var serviceContextSecondary = new TestServiceContext + { + DateHeaderValueManager = serviceContextPrimary.DateHeaderValueManager, + ServerOptions = serviceContextPrimary.ServerOptions, + Scheduler = serviceContextPrimary.Scheduler, + HttpParser = serviceContextPrimary.HttpParser, + }; + var builderSecondary = new ConnectionBuilder(); + builderSecondary.UseHttpServer(serviceContextSecondary, new DummyApplication(c => c.Response.WriteAsync("Secondary")), HttpProtocols.Http1); + var transportContextSecondary = new TestLibuvTransportContext(); + transportContextSecondary.ConnectionDispatcher = new ConnectionDispatcher(serviceContextSecondary, builderSecondary.Build()); + + var libuvTransport = new LibuvTransport(libuv, transportContextPrimary, listenOptions); + + var pipeName = (libuv.IsWindows ? @"\\.\pipe\kestrel_" : "/tmp/kestrel_") + Guid.NewGuid().ToString("n"); + var pipeMessage = Guid.NewGuid().ToByteArray(); + + // Start primary listener + var libuvThreadPrimary = new LibuvThread(libuvTransport); + await libuvThreadPrimary.StartAsync(); + var listenerPrimary = new ListenerPrimary(transportContextPrimary); + await listenerPrimary.StartAsync(pipeName, pipeMessage, listenOptions, libuvThreadPrimary); + var address = GetUri(listenOptions); + + // Add secondary listener with wrong pipe message + var libuvThreadSecondary = new LibuvThread(libuvTransport); + await libuvThreadSecondary.StartAsync(); + var listenerSecondary = new ListenerSecondary(transportContextSecondary); + await listenerSecondary.StartAsync(pipeName, Guid.NewGuid().ToByteArray(), listenOptions, libuvThreadSecondary); + + // Wait up to 10 seconds for error to be logged + for (var i = 0; i < 10 && logger.TotalErrorsLogged == 0; i++) + { + await Task.Delay(100); + } + + // TCP Connections don't get round-robined + Assert.Equal("Primary", await HttpClientSlim.GetStringAsync(address)); + Assert.Equal("Primary", await HttpClientSlim.GetStringAsync(address)); + Assert.Equal("Primary", await HttpClientSlim.GetStringAsync(address)); + + await listenerSecondary.DisposeAsync(); + await libuvThreadSecondary.StopAsync(TimeSpan.FromSeconds(5)); + + await listenerPrimary.DisposeAsync(); + await libuvThreadPrimary.StopAsync(TimeSpan.FromSeconds(5)); + + Assert.Equal(1, logger.TotalErrorsLogged); + var errorMessage = logger.Messages.First(m => m.LogLevel == LogLevel.Error); + Assert.IsType(errorMessage.Exception); + Assert.Contains("Bad data", errorMessage.Exception.ToString()); + } + + private static async Task AssertResponseEventually( + Uri address, + string expected, + string[] allowed = null, + int maxRetries = 100, + int retryDelay = 100) + { + for (var i = 0; i < maxRetries; i++) + { + var response = await HttpClientSlim.GetStringAsync(address); + if (response == expected) + { + return; + } + + if (allowed != null) + { + Assert.Contains(response, allowed); + } + + await Task.Delay(retryDelay); + } + + Assert.True(false, $"'{address}' failed to respond with '{expected}' in {maxRetries} retries."); + } + + private static Uri GetUri(ListenOptions options) + { + if (options.Type != ListenType.IPEndPoint) + { + throw new InvalidOperationException($"Could not determine a proper URI for options with Type {options.Type}"); + } + + var scheme = options.ConnectionAdapters.Any(f => f.IsHttps) + ? "https" + : "http"; + + return new Uri($"{scheme}://{options.IPEndPoint}"); + } + + private class ConnectionBuilder : IConnectionBuilder + { + private readonly List> _components = new List>(); + + public IServiceProvider ApplicationServices { get; set; } + + public IConnectionBuilder Use(Func middleware) + { + _components.Add(middleware); + return this; + } + + public ConnectionDelegate Build() + { + ConnectionDelegate app = context => + { + return Task.CompletedTask; + }; + + for (int i = _components.Count - 1; i >= 0; i--) + { + var component = _components[i]; + app = component(app); + } + + return app; + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/test/Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.csproj b/src/Servers/Kestrel/Transport.Libuv/test/Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.csproj new file mode 100644 index 0000000000..4aa1f389e4 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.csproj @@ -0,0 +1,20 @@ + + + + netcoreapp2.1;net461 + true + true + + + + + + + + + + + + + + diff --git a/src/Servers/Kestrel/Transport.Libuv/test/MultipleLoopTests.cs b/src/Servers/Kestrel/Transport.Libuv/test/MultipleLoopTests.cs new file mode 100644 index 0000000000..463165ef88 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/MultipleLoopTests.cs @@ -0,0 +1,251 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Threading; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.AspNetCore.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests +{ + public class MultipleLoopTests + { + private readonly LibuvFunctions _uv = new LibuvFunctions(); + private readonly ILibuvTrace _logger = new LibuvTrace(new TestApplicationErrorLogger()); + + [Fact] + public void InitAndCloseServerPipe() + { + var loop = new UvLoopHandle(_logger); + var pipe = new UvPipeHandle(_logger); + + loop.Init(_uv); + pipe.Init(loop, (a, b) => { }, true); + pipe.Bind(@"\\.\pipe\InitAndCloseServerPipe"); + pipe.Dispose(); + + loop.Run(); + + pipe.Dispose(); + loop.Dispose(); + } + + [Fact] + public void ServerPipeListenForConnections() + { + const string pipeName = @"\\.\pipe\ServerPipeListenForConnections"; + + var loop = new UvLoopHandle(_logger); + var serverListenPipe = new UvPipeHandle(_logger); + + loop.Init(_uv); + serverListenPipe.Init(loop, (a, b) => { }, false); + serverListenPipe.Bind(pipeName); + serverListenPipe.Listen(128, async (backlog, status, error, state) => + { + var serverConnectionPipe = new UvPipeHandle(_logger); + serverConnectionPipe.Init(loop, (a, b) => { }, true); + + try + { + serverListenPipe.Accept(serverConnectionPipe); + } + catch (Exception) + { + serverConnectionPipe.Dispose(); + return; + } + + var writeRequest = new UvWriteReq(_logger); + writeRequest.DangerousInit(loop); + + await writeRequest.WriteAsync( + serverConnectionPipe, + new ReadOnlySequence(new byte[] { 1, 2, 3, 4 })); + + writeRequest.Dispose(); + serverConnectionPipe.Dispose(); + serverListenPipe.Dispose(); + + }, null); + + var worker = new Thread(() => + { + var loop2 = new UvLoopHandle(_logger); + var clientConnectionPipe = new UvPipeHandle(_logger); + var connect = new UvConnectRequest(_logger); + + loop2.Init(_uv); + clientConnectionPipe.Init(loop2, (a, b) => { }, true); + connect.DangerousInit(loop2); + connect.Connect(clientConnectionPipe, pipeName, (handle, status, error, state) => + { + var buf = loop2.Libuv.buf_init(Marshal.AllocHGlobal(8192), 8192); + connect.Dispose(); + + clientConnectionPipe.ReadStart( + (handle2, cb, state2) => buf, + (handle2, status2, state2) => + { + if (status2 == TestConstants.EOF) + { + clientConnectionPipe.Dispose(); + } + }, + null); + }, null); + loop2.Run(); + loop2.Dispose(); + }); + worker.Start(); + loop.Run(); + loop.Dispose(); + worker.Join(); + } + + [Fact] + public void ServerPipeDispatchConnections() + { + var pipeName = @"\\.\pipe\ServerPipeDispatchConnections" + Guid.NewGuid().ToString("n"); + + var loop = new UvLoopHandle(_logger); + loop.Init(_uv); + + var serverConnectionPipe = default(UvPipeHandle); + var serverConnectionPipeAcceptedEvent = new ManualResetEvent(false); + var serverConnectionTcpDisposedEvent = new ManualResetEvent(false); + + var serverListenPipe = new UvPipeHandle(_logger); + serverListenPipe.Init(loop, (a, b) => { }, false); + serverListenPipe.Bind(pipeName); + serverListenPipe.Listen(128, (handle, status, error, state) => + { + serverConnectionPipe = new UvPipeHandle(_logger); + serverConnectionPipe.Init(loop, (a, b) => { }, true); + + try + { + serverListenPipe.Accept(serverConnectionPipe); + serverConnectionPipeAcceptedEvent.Set(); + } + catch (Exception ex) + { + Console.WriteLine(ex); + serverConnectionPipe.Dispose(); + serverConnectionPipe = null; + } + }, null); + + var serverListenTcp = new UvTcpHandle(_logger); + serverListenTcp.Init(loop, (a, b) => { }); + var endPoint = new IPEndPoint(IPAddress.Loopback, 0); + serverListenTcp.Bind(endPoint); + var port = serverListenTcp.GetSockIPEndPoint().Port; + serverListenTcp.Listen(128, (handle, status, error, state) => + { + var serverConnectionTcp = new UvTcpHandle(_logger); + serverConnectionTcp.Init(loop, (a, b) => { }); + serverListenTcp.Accept(serverConnectionTcp); + + serverConnectionPipeAcceptedEvent.WaitOne(); + + var writeRequest = new UvWriteReq(_logger); + writeRequest.DangerousInit(loop); + writeRequest.Write2( + serverConnectionPipe, + new ArraySegment>(new ArraySegment[] { new ArraySegment(new byte[] { 1, 2, 3, 4 }) }), + serverConnectionTcp, + (handle2, status2, error2, state2) => + { + writeRequest.Dispose(); + serverConnectionTcp.Dispose(); + serverConnectionTcpDisposedEvent.Set(); + serverConnectionPipe.Dispose(); + serverListenPipe.Dispose(); + serverListenTcp.Dispose(); + }, + null); + }, null); + + var worker = new Thread(() => + { + var loop2 = new UvLoopHandle(_logger); + var clientConnectionPipe = new UvPipeHandle(_logger); + var connect = new UvConnectRequest(_logger); + + loop2.Init(_uv); + clientConnectionPipe.Init(loop2, (a, b) => { }, true); + connect.DangerousInit(loop2); + connect.Connect(clientConnectionPipe, pipeName, (handle, status, error, state) => + { + connect.Dispose(); + + var buf = loop2.Libuv.buf_init(Marshal.AllocHGlobal(64), 64); + + serverConnectionTcpDisposedEvent.WaitOne(); + + clientConnectionPipe.ReadStart( + (handle2, cb, state2) => buf, + (handle2, status2, state2) => + { + if (status2 == TestConstants.EOF) + { + clientConnectionPipe.Dispose(); + return; + } + + var clientConnectionTcp = new UvTcpHandle(_logger); + clientConnectionTcp.Init(loop2, (a, b) => { }); + clientConnectionPipe.Accept(clientConnectionTcp); + var buf2 = loop2.Libuv.buf_init(Marshal.AllocHGlobal(64), 64); + clientConnectionTcp.ReadStart( + (handle3, cb, state3) => buf2, + (handle3, status3, state3) => + { + if (status3 == TestConstants.EOF) + { + clientConnectionTcp.Dispose(); + } + }, + null); + }, + null); + }, null); + loop2.Run(); + loop2.Dispose(); + }); + + var worker2 = new Thread(() => + { + try + { + serverConnectionPipeAcceptedEvent.WaitOne(); + + var socket = TestConnection.CreateConnectedLoopbackSocket(port); + socket.Send(new byte[] { 6, 7, 8, 9 }); + socket.Shutdown(SocketShutdown.Send); + var cb = socket.Receive(new byte[64]); + socket.Dispose(); + } + catch (Exception ex) + { + Console.WriteLine(ex); + } + }); + + worker.Start(); + worker2.Start(); + + loop.Run(); + loop.Dispose(); + worker.Join(); + worker2.Join(); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/test/NetworkingTests.cs b/src/Servers/Kestrel/Transport.Libuv/test/NetworkingTests.cs new file mode 100644 index 0000000000..0e698d5477 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/NetworkingTests.cs @@ -0,0 +1,194 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.AspNetCore.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests +{ + /// + /// Summary description for NetworkingTests + /// + public class NetworkingTests + { + private readonly LibuvFunctions _uv = new LibuvFunctions(); + private readonly ILibuvTrace _logger = new LibuvTrace(new TestApplicationErrorLogger()); + + [Fact] + public void LoopCanBeInitAndClose() + { + var loop = new UvLoopHandle(_logger); + loop.Init(_uv); + loop.Run(); + loop.Dispose(); + } + + [Fact] + public void AsyncCanBeSent() + { + var loop = new UvLoopHandle(_logger); + loop.Init(_uv); + var trigger = new UvAsyncHandle(_logger); + var called = false; + trigger.Init(loop, () => + { + called = true; + trigger.Dispose(); + }, (a, b) => { }); + trigger.Send(); + loop.Run(); + loop.Dispose(); + Assert.True(called); + } + + [Fact] + public void SocketCanBeInitAndClose() + { + var loop = new UvLoopHandle(_logger); + loop.Init(_uv); + var tcp = new UvTcpHandle(_logger); + tcp.Init(loop, (a, b) => { }); + var endPoint = new IPEndPoint(IPAddress.Loopback, 0); + tcp.Bind(endPoint); + tcp.Dispose(); + loop.Run(); + loop.Dispose(); + } + + [Fact] + public async Task SocketCanListenAndAccept() + { + var loop = new UvLoopHandle(_logger); + loop.Init(_uv); + var tcp = new UvTcpHandle(_logger); + tcp.Init(loop, (a, b) => { }); + var endPoint = new IPEndPoint(IPAddress.Loopback, 0); + tcp.Bind(endPoint); + var port = tcp.GetSockIPEndPoint().Port; + tcp.Listen(10, (stream, status, error, state) => + { + var tcp2 = new UvTcpHandle(_logger); + tcp2.Init(loop, (a, b) => { }); + stream.Accept(tcp2); + tcp2.Dispose(); + stream.Dispose(); + }, null); + var t = Task.Run(() => + { + var socket = TestConnection.CreateConnectedLoopbackSocket(port); + socket.Dispose(); + }); + loop.Run(); + loop.Dispose(); + await t; + } + + [Fact] + public async Task SocketCanRead() + { + var loop = new UvLoopHandle(_logger); + loop.Init(_uv); + var tcp = new UvTcpHandle(_logger); + tcp.Init(loop, (a, b) => { }); + var endPoint = new IPEndPoint(IPAddress.Loopback, 0); + tcp.Bind(endPoint); + var port = tcp.GetSockIPEndPoint().Port; + tcp.Listen(10, (_, status, error, state) => + { + var tcp2 = new UvTcpHandle(_logger); + tcp2.Init(loop, (a, b) => { }); + tcp.Accept(tcp2); + var data = Marshal.AllocCoTaskMem(500); + tcp2.ReadStart( + (a, b, c) => _uv.buf_init(data, 500), + (__, nread, state2) => + { + if (nread <= 0) + { + tcp2.Dispose(); + } + }, + null); + tcp.Dispose(); + }, null); + var t = Task.Run(async () => + { + var socket = TestConnection.CreateConnectedLoopbackSocket(port); + await socket.SendAsync(new[] { new ArraySegment(new byte[] { 1, 2, 3, 4, 5 }) }, + SocketFlags.None); + socket.Dispose(); + }); + loop.Run(); + loop.Dispose(); + await t; + } + + [Fact] + public async Task SocketCanReadAndWrite() + { + var loop = new UvLoopHandle(_logger); + loop.Init(_uv); + var tcp = new UvTcpHandle(_logger); + tcp.Init(loop, (a, b) => { }); + var endPoint = new IPEndPoint(IPAddress.Loopback, 0); + tcp.Bind(endPoint); + var port = tcp.GetSockIPEndPoint().Port; + tcp.Listen(10, (_, status, error, state) => + { + var tcp2 = new UvTcpHandle(_logger); + tcp2.Init(loop, (a, b) => { }); + tcp.Accept(tcp2); + var data = Marshal.AllocCoTaskMem(500); + tcp2.ReadStart( + (a, b, c) => tcp2.Libuv.buf_init(data, 500), + async (__, nread, state2) => + { + if (nread <= 0) + { + tcp2.Dispose(); + } + else + { + for (var x = 0; x < 2; x++) + { + var req = new UvWriteReq(_logger); + req.DangerousInit(loop); + var block = new ReadOnlySequence(new byte[] { 65, 66, 67, 68, 69 }); + + await req.WriteAsync( + tcp2, + block); + } + } + }, + null); + tcp.Dispose(); + }, null); + var t = Task.Run(async () => + { + var socket = TestConnection.CreateConnectedLoopbackSocket(port); + await socket.SendAsync(new[] { new ArraySegment(new byte[] { 1, 2, 3, 4, 5 }) }, + SocketFlags.None); + socket.Shutdown(SocketShutdown.Send); + var buffer = new ArraySegment(new byte[2048]); + while (true) + { + var count = await socket.ReceiveAsync(new[] { buffer }, SocketFlags.None); + if (count <= 0) break; + } + socket.Dispose(); + }); + loop.Run(); + loop.Dispose(); + await t; + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Transport.Libuv/test/TestHelpers/MockConnectionDispatcher.cs b/src/Servers/Kestrel/Transport.Libuv/test/TestHelpers/MockConnectionDispatcher.cs new file mode 100644 index 0000000000..73cdfbb3a2 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/TestHelpers/MockConnectionDispatcher.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.TestHelpers +{ + public class MockConnectionDispatcher : IConnectionDispatcher + { + public Func, PipeOptions> InputOptions { get; set; } = pool => new PipeOptions(pool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + public Func, PipeOptions> OutputOptions { get; set; } = pool => new PipeOptions(pool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + + public void OnConnection(TransportConnection connection) + { + Input = new Pipe(InputOptions(connection.MemoryPool)); + Output = new Pipe(OutputOptions(connection.MemoryPool)); + + connection.Transport = new DuplexPipe(Input.Reader, Output.Writer); + connection.Application = new DuplexPipe(Output.Reader, Input.Writer); + } + + public Pipe Input { get; private set; } + public Pipe Output { get; private set; } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/test/TestHelpers/MockLibuv.cs b/src/Servers/Kestrel/Transport.Libuv/test/TestHelpers/MockLibuv.cs new file mode 100644 index 0000000000..6cb3f6917a --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/TestHelpers/MockLibuv.cs @@ -0,0 +1,165 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.TestHelpers +{ + public class MockLibuv : LibuvFunctions + { + private UvAsyncHandle _postHandle; + private uv_async_cb _onPost; + + private readonly object _postLock = new object(); + private TaskCompletionSource _onPostTcs = new TaskCompletionSource(); + private bool _completedOnPostTcs; + + private bool _stopLoop; + private readonly ManualResetEventSlim _loopWh = new ManualResetEventSlim(); + + private readonly string _stackTrace; + + unsafe public MockLibuv() + : base(onlyForTesting: true) + { + _stackTrace = Environment.StackTrace; + + OnWrite = (socket, buffers, triggerCompleted) => + { + triggerCompleted(0); + return 0; + }; + + _uv_write = UvWrite; + + _uv_async_send = postHandle => + { + lock (_postLock) + { + if (_completedOnPostTcs) + { + _onPostTcs = new TaskCompletionSource(); + _completedOnPostTcs = false; + } + + PostCount++; + + _loopWh.Set(); + } + + return 0; + }; + + _uv_async_init = (loop, postHandle, callback) => + { + _postHandle = postHandle; + _onPost = callback; + + return 0; + }; + + _uv_run = (loopHandle, mode) => + { + while (!_stopLoop) + { + _loopWh.Wait(); + KestrelThreadBlocker.Wait(); + + lock (_postLock) + { + _loopWh.Reset(); + } + + _onPost(_postHandle.InternalGetHandle()); + + lock (_postLock) + { + // Allow the loop to be run again before completing + // _onPostTcs given a nested uv_async_send call. + if (!_loopWh.IsSet) + { + // Ensure any subsequent calls to uv_async_send + // create a new _onPostTcs to be completed. + _completedOnPostTcs = true; + + // Calling TrySetResult outside the lock to avoid deadlock + // when the code attempts to call uv_async_send after awaiting + // OnPostTask. Task.Run so the run loop doesn't block either. + var onPostTcs = _onPostTcs; + Task.Run(() => onPostTcs.TrySetResult(null)); + } + } + } + + return 0; + }; + + _uv_ref = handle => { }; + _uv_unref = handle => + { + _stopLoop = true; + _loopWh.Set(); + }; + + _uv_stop = handle => + { + _stopLoop = true; + _loopWh.Set(); + }; + + _uv_req_size = reqType => IntPtr.Size; + _uv_loop_size = () => IntPtr.Size; + _uv_handle_size = handleType => IntPtr.Size; + _uv_loop_init = loop => 0; + _uv_tcp_init = (loopHandle, tcpHandle) => 0; + _uv_close = (handle, callback) => callback(handle); + _uv_loop_close = handle => 0; + _uv_walk = (loop, callback, ignore) => 0; + _uv_err_name = errno => IntPtr.Zero; + _uv_strerror = errno => IntPtr.Zero; + _uv_read_start = UvReadStart; + _uv_read_stop = (handle) => + { + AllocCallback = null; + ReadCallback = null; + return 0; + }; + _uv_unsafe_async_send = handle => + { + throw new Exception($"Why is this getting called?{Environment.NewLine}{_stackTrace}"); + }; + + _uv_timer_init = (loop, handle) => 0; + _uv_timer_start = (handle, callback, timeout, repeat) => 0; + _uv_timer_stop = handle => 0; + _uv_now = (loop) => DateTime.UtcNow.Ticks / TimeSpan.TicksPerMillisecond; + } + + public Func, int> OnWrite { get; set; } + + public uv_alloc_cb AllocCallback { get; set; } + + public uv_read_cb ReadCallback { get; set; } + + public int PostCount { get; set; } + + public Task OnPostTask => _onPostTcs.Task; + + public ManualResetEventSlim KestrelThreadBlocker { get; } = new ManualResetEventSlim(true); + + private int UvReadStart(UvStreamHandle handle, uv_alloc_cb allocCallback, uv_read_cb readCallback) + { + AllocCallback = allocCallback; + ReadCallback = readCallback; + return 0; + } + + unsafe private int UvWrite(UvRequest req, UvStreamHandle handle, uv_buf_t* bufs, int nbufs, uv_write_cb cb) + { + return OnWrite(handle, nbufs, status => cb(req.InternalGetHandle(), status)); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/test/TestHelpers/MockSocket.cs b/src/Servers/Kestrel/Transport.Libuv/test/TestHelpers/MockSocket.cs new file mode 100644 index 0000000000..110f8667ba --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/TestHelpers/MockSocket.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.TestHelpers +{ + class MockSocket : UvStreamHandle + { + public MockSocket(LibuvFunctions uv, int threadId, ILibuvTrace logger) : base(logger) + { + CreateMemory(uv, threadId, IntPtr.Size); + } + + protected override bool ReleaseHandle() + { + DestroyMemory(handle); + handle = IntPtr.Zero; + return true; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/test/TestHelpers/TestLibuvTransportContext.cs b/src/Servers/Kestrel/Transport.Libuv/test/TestHelpers/TestLibuvTransportContext.cs new file mode 100644 index 0000000000..c779b87647 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/TestHelpers/TestLibuvTransportContext.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Testing; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.TestHelpers +{ + public class TestLibuvTransportContext : LibuvTransportContext + { + public TestLibuvTransportContext() + { + var logger = new TestApplicationErrorLogger(); + + AppLifetime = new LifetimeNotImplemented(); + ConnectionDispatcher = new MockConnectionDispatcher(); + Log = new LibuvTrace(logger); + Options = new LibuvTransportOptions { ThreadCount = 1 }; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Libuv/test/UvStreamHandleTests.cs b/src/Servers/Kestrel/Transport.Libuv/test/UvStreamHandleTests.cs new file mode 100644 index 0000000000..d52e120add --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/UvStreamHandleTests.cs @@ -0,0 +1,31 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.TestHelpers; +using Microsoft.AspNetCore.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests +{ + public class UvStreamHandleTests + { + [Fact] + public void ReadStopIsIdempotent() + { + var libuvTrace = new LibuvTrace(new TestApplicationErrorLogger()); + + using (var uvLoopHandle = new UvLoopHandle(libuvTrace)) + using (var uvTcpHandle = new UvTcpHandle(libuvTrace)) + { + uvLoopHandle.Init(new MockLibuv()); + uvTcpHandle.Init(uvLoopHandle, null); + + UvStreamHandle uvStreamHandle = uvTcpHandle; + uvStreamHandle.ReadStop(); + uvStreamHandle.ReadStop(); + } + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/Transport.Libuv/test/UvTimerHandleTests.cs b/src/Servers/Kestrel/Transport.Libuv/test/UvTimerHandleTests.cs new file mode 100644 index 0000000000..0eafb1a48b --- /dev/null +++ b/src/Servers/Kestrel/Transport.Libuv/test/UvTimerHandleTests.cs @@ -0,0 +1,71 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; +using Microsoft.AspNetCore.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests +{ + public class UvTimerHandleTests + { + private readonly ILibuvTrace _trace = new LibuvTrace(new TestApplicationErrorLogger()); + + [Fact] + public void TestTimeout() + { + var loop = new UvLoopHandle(_trace); + loop.Init(new LibuvFunctions()); + + var timer = new UvTimerHandle(_trace); + timer.Init(loop, (a, b) => { }); + + var callbackInvoked = false; + timer.Start(_ => + { + callbackInvoked = true; + }, 1, 0); + loop.Run(); + + timer.Dispose(); + loop.Run(); + + loop.Dispose(); + + Assert.True(callbackInvoked); + } + + [Fact] + public void TestRepeat() + { + var loop = new UvLoopHandle(_trace); + loop.Init(new LibuvFunctions()); + + var timer = new UvTimerHandle(_trace); + timer.Init(loop, (callback, handle) => { }); + + var callbackCount = 0; + timer.Start(_ => + { + if (callbackCount < 2) + { + callbackCount++; + } + else + { + timer.Stop(); + } + }, 1, 1); + + loop.Run(); + + timer.Dispose(); + loop.Run(); + + loop.Dispose(); + + Assert.Equal(2, callbackCount); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/BufferExtensions.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/BufferExtensions.cs new file mode 100644 index 0000000000..9985dfbc5b --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/BufferExtensions.cs @@ -0,0 +1,25 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + public static class BufferExtensions + { + public static ArraySegment GetArray(this Memory memory) + { + return ((ReadOnlyMemory)memory).GetArray(); + } + + public static ArraySegment GetArray(this ReadOnlyMemory memory) + { + if (!MemoryMarshal.TryGetArray(memory, out var result)) + { + throw new InvalidOperationException("Buffer backed by array was expected"); + } + return result; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/IOQueue.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/IOQueue.cs new file mode 100644 index 0000000000..6a3e60d1ac --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/IOQueue.cs @@ -0,0 +1,65 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.IO.Pipelines; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + public class IOQueue : PipeScheduler + { + private static readonly WaitCallback _doWorkCallback = s => ((IOQueue)s).DoWork(); + + private readonly object _workSync = new object(); + private readonly ConcurrentQueue _workItems = new ConcurrentQueue(); + private bool _doingWork; + + public override void Schedule(Action action, object state) + { + var work = new Work + { + Callback = action, + State = state + }; + + _workItems.Enqueue(work); + + lock (_workSync) + { + if (!_doingWork) + { + System.Threading.ThreadPool.QueueUserWorkItem(_doWorkCallback, this); + _doingWork = true; + } + } + } + + private void DoWork() + { + while (true) + { + while (_workItems.TryDequeue(out Work item)) + { + item.Callback(item.State); + } + + lock (_workSync) + { + if (_workItems.IsEmpty) + { + _doingWork = false; + return; + } + } + } + } + + private struct Work + { + public Action Callback; + public object State; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/ISocketsTrace.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/ISocketsTrace.cs new file mode 100644 index 0000000000..06c2c36f20 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/ISocketsTrace.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + public interface ISocketsTrace : ILogger + { + void ConnectionReadFin(string connectionId); + + void ConnectionWriteFin(string connectionId); + + void ConnectionError(string connectionId, Exception ex); + + void ConnectionReset(string connectionId); + + void ConnectionPause(string connectionId); + + void ConnectionResume(string connectionId); + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketAwaitable.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketAwaitable.cs new file mode 100644 index 0000000000..6c4de75c45 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketAwaitable.cs @@ -0,0 +1,72 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.IO.Pipelines; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + public class SocketAwaitable : ICriticalNotifyCompletion + { + private static readonly Action _callbackCompleted = () => { }; + + private readonly PipeScheduler _ioScheduler; + + private Action _callback; + private int _bytesTransferred; + private SocketError _error; + + public SocketAwaitable(PipeScheduler ioScheduler) + { + _ioScheduler = ioScheduler; + } + + public SocketAwaitable GetAwaiter() => this; + public bool IsCompleted => ReferenceEquals(_callback, _callbackCompleted); + + public int GetResult() + { + Debug.Assert(ReferenceEquals(_callback, _callbackCompleted)); + + _callback = null; + + if (_error != SocketError.Success) + { + throw new SocketException((int)_error); + } + + return _bytesTransferred; + } + + public void OnCompleted(Action continuation) + { + if (ReferenceEquals(_callback, _callbackCompleted) || + ReferenceEquals(Interlocked.CompareExchange(ref _callback, continuation, null), _callbackCompleted)) + { + Task.Run(continuation); + } + } + + public void UnsafeOnCompleted(Action continuation) + { + OnCompleted(continuation); + } + + public void Complete(int bytesTransferred, SocketError socketError) + { + _error = socketError; + _bytesTransferred = bytesTransferred; + var continuation = Interlocked.Exchange(ref _callback, _callbackCompleted); + + if (continuation != null) + { + _ioScheduler.Schedule(state => ((Action)state)(), continuation); + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.cs new file mode 100644 index 0000000000..db99693f67 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.cs @@ -0,0 +1,322 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Diagnostics; +using System.IO; +using System.IO.Pipelines; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + internal sealed class SocketConnection : TransportConnection + { + private static readonly int MinAllocBufferSize = KestrelMemoryPool.MinimumSegmentSize / 2; + private static readonly bool IsWindows = RuntimeInformation.IsOSPlatform(OSPlatform.Windows); + + private readonly Socket _socket; + private readonly PipeScheduler _scheduler; + private readonly ISocketsTrace _trace; + private readonly SocketReceiver _receiver; + private readonly SocketSender _sender; + private readonly CancellationTokenSource _connectionClosedTokenSource = new CancellationTokenSource(); + + private readonly object _shutdownLock = new object(); + private volatile bool _aborted; + private volatile ConnectionAbortedException _abortReason; + private long _totalBytesWritten; + + internal SocketConnection(Socket socket, MemoryPool memoryPool, PipeScheduler scheduler, ISocketsTrace trace) + { + Debug.Assert(socket != null); + Debug.Assert(memoryPool != null); + Debug.Assert(trace != null); + + _socket = socket; + MemoryPool = memoryPool; + _scheduler = scheduler; + _trace = trace; + + var localEndPoint = (IPEndPoint)_socket.LocalEndPoint; + var remoteEndPoint = (IPEndPoint)_socket.RemoteEndPoint; + + LocalAddress = localEndPoint.Address; + LocalPort = localEndPoint.Port; + + RemoteAddress = remoteEndPoint.Address; + RemotePort = remoteEndPoint.Port; + + ConnectionClosed = _connectionClosedTokenSource.Token; + + // On *nix platforms, Sockets already dispatches to the ThreadPool. + // Yes, the IOQueues are still used for the PipeSchedulers. This is intentional. + // https://github.com/aspnet/KestrelHttpServer/issues/2573 + var awaiterScheduler = IsWindows ? _scheduler : PipeScheduler.Inline; + + _receiver = new SocketReceiver(_socket, awaiterScheduler); + _sender = new SocketSender(_socket, awaiterScheduler); + } + + public override MemoryPool MemoryPool { get; } + public override PipeScheduler InputWriterScheduler => _scheduler; + public override PipeScheduler OutputReaderScheduler => _scheduler; + public override long TotalBytesWritten => Interlocked.Read(ref _totalBytesWritten); + + public async Task StartAsync() + { + try + { + // Spawn send and receive logic + var receiveTask = DoReceive(); + var sendTask = DoSend(); + + // Now wait for both to complete + await receiveTask; + await sendTask; + + _receiver.Dispose(); + _sender.Dispose(); + ThreadPool.QueueUserWorkItem(state => ((SocketConnection)state).CancelConnectionClosedToken(), this); + } + catch (Exception ex) + { + _trace.LogError(0, ex, $"Unexpected exception in {nameof(SocketConnection)}.{nameof(StartAsync)}."); + } + } + + public override void Abort(ConnectionAbortedException abortReason) + { + _abortReason = abortReason; + Output.CancelPendingRead(); + + // Try to gracefully close the socket to match libuv behavior. + Shutdown(); + } + + private async Task DoReceive() + { + Exception error = null; + + try + { + await ProcessReceives(); + } + catch (SocketException ex) when (IsConnectionResetError(ex.SocketErrorCode)) + { + // A connection reset can be reported as SocketError.ConnectionAborted on Windows + if (!_aborted) + { + error = new ConnectionResetException(ex.Message, ex); + _trace.ConnectionReset(ConnectionId); + } + } + catch (SocketException ex) when (IsConnectionAbortError(ex.SocketErrorCode)) + { + if (!_aborted) + { + // Calling Dispose after ReceiveAsync can cause an "InvalidArgument" error on *nix. + _trace.ConnectionError(ConnectionId, error); + } + } + catch (ObjectDisposedException) + { + if (!_aborted) + { + _trace.ConnectionError(ConnectionId, error); + } + } + catch (IOException ex) + { + error = ex; + _trace.ConnectionError(ConnectionId, error); + } + catch (Exception ex) + { + error = new IOException(ex.Message, ex); + _trace.ConnectionError(ConnectionId, error); + } + finally + { + if (_aborted) + { + error = error ?? _abortReason ?? new ConnectionAbortedException(); + } + + Input.Complete(error); + } + } + + private async Task ProcessReceives() + { + while (true) + { + // Ensure we have some reasonable amount of buffer space + var buffer = Input.GetMemory(MinAllocBufferSize); + + var bytesReceived = await _receiver.ReceiveAsync(buffer); + + if (bytesReceived == 0) + { + // FIN + _trace.ConnectionReadFin(ConnectionId); + break; + } + + Input.Advance(bytesReceived); + + var flushTask = Input.FlushAsync(); + + if (!flushTask.IsCompleted) + { + _trace.ConnectionPause(ConnectionId); + + await flushTask; + + _trace.ConnectionResume(ConnectionId); + } + + var result = flushTask.GetAwaiter().GetResult(); + if (result.IsCompleted) + { + // Pipe consumer is shut down, do we stop writing + break; + } + } + } + + private async Task DoSend() + { + Exception error = null; + + try + { + await ProcessSends(); + } + catch (SocketException ex) when (IsConnectionResetError(ex.SocketErrorCode)) + { + // A connection reset can be reported as SocketError.ConnectionAborted on Windows + error = null; + _trace.ConnectionReset(ConnectionId); + } + catch (SocketException ex) when (IsConnectionAbortError(ex.SocketErrorCode)) + { + error = null; + } + catch (ObjectDisposedException) + { + error = null; + } + catch (IOException ex) + { + error = ex; + _trace.ConnectionError(ConnectionId, error); + } + catch (Exception ex) + { + error = new IOException(ex.Message, ex); + _trace.ConnectionError(ConnectionId, error); + } + finally + { + Shutdown(); + + // Complete the output after disposing the socket + Output.Complete(error); + } + } + + private async Task ProcessSends() + { + while (true) + { + var result = await Output.ReadAsync(); + + var buffer = result.Buffer; + + if (result.IsCanceled) + { + break; + } + + var end = buffer.End; + var isCompleted = result.IsCompleted; + if (!buffer.IsEmpty) + { + await _sender.SendAsync(buffer); + } + + // This is not interlocked because there could be a concurrent writer. + // Instead it's to prevent read tearing on 32-bit systems. + Interlocked.Add(ref _totalBytesWritten, buffer.Length); + + Output.AdvanceTo(end); + + if (isCompleted) + { + break; + } + } + } + + private void Shutdown() + { + lock (_shutdownLock) + { + if (!_aborted) + { + // Make sure to close the connection only after the _aborted flag is set. + // Without this, the RequestsCanBeAbortedMidRead test will sometimes fail when + // a BadHttpRequestException is thrown instead of a TaskCanceledException. + _aborted = true; + _trace.ConnectionWriteFin(ConnectionId); + + try + { + // Try to gracefully close the socket even for aborts to match libuv behavior. + _socket.Shutdown(SocketShutdown.Both); + } + catch + { + // Ignore any errors from Socket.Shutdown since we're tearing down the connection anyway. + } + + _socket.Dispose(); + } + } + } + + private void CancelConnectionClosedToken() + { + try + { + _connectionClosedTokenSource.Cancel(); + } + catch (Exception ex) + { + _trace.LogError(0, ex, $"Unexpected exception in {nameof(SocketConnection)}.{nameof(CancelConnectionClosedToken)}."); + } + } + + private static bool IsConnectionResetError(SocketError errorCode) + { + return errorCode == SocketError.ConnectionReset || + errorCode == SocketError.ConnectionAborted || + errorCode == SocketError.Shutdown; + } + + private static bool IsConnectionAbortError(SocketError errorCode) + { + return errorCode == SocketError.OperationAborted || + errorCode == SocketError.Interrupted || + errorCode == SocketError.InvalidArgument; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketReceiver.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketReceiver.cs new file mode 100644 index 0000000000..2116c03cd5 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketReceiver.cs @@ -0,0 +1,46 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Net.Sockets; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + public class SocketReceiver : IDisposable + { + private readonly Socket _socket; + private readonly SocketAsyncEventArgs _eventArgs = new SocketAsyncEventArgs(); + private readonly SocketAwaitable _awaitable; + + public SocketReceiver(Socket socket, PipeScheduler scheduler) + { + _socket = socket; + _awaitable = new SocketAwaitable(scheduler); + _eventArgs.UserToken = _awaitable; + _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); + } + + public SocketAwaitable ReceiveAsync(Memory buffer) + { +#if NETCOREAPP2_1 + _eventArgs.SetBuffer(buffer); +#else + var segment = buffer.GetArray(); + + _eventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); +#endif + if (!_socket.ReceiveAsync(_eventArgs)) + { + _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); + } + + return _awaitable; + } + + public void Dispose() + { + _eventArgs.Dispose(); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketSender.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketSender.cs new file mode 100644 index 0000000000..26e22b664d --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketSender.cs @@ -0,0 +1,107 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO.Pipelines; +using System.Net.Sockets; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + public class SocketSender : IDisposable + { + private readonly Socket _socket; + private readonly SocketAsyncEventArgs _eventArgs = new SocketAsyncEventArgs(); + private readonly SocketAwaitable _awaitable; + + private List> _bufferList; + + public SocketSender(Socket socket, PipeScheduler scheduler) + { + _socket = socket; + _awaitable = new SocketAwaitable(scheduler); + _eventArgs.UserToken = _awaitable; + _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); + } + + public SocketAwaitable SendAsync(ReadOnlySequence buffers) + { + if (buffers.IsSingleSegment) + { + return SendAsync(buffers.First); + } + +#if NETCOREAPP2_1 + if (!_eventArgs.MemoryBuffer.Equals(Memory.Empty)) +#else + if (_eventArgs.Buffer != null) +#endif + { + _eventArgs.SetBuffer(null, 0, 0); + } + + _eventArgs.BufferList = GetBufferList(buffers); + + if (!_socket.SendAsync(_eventArgs)) + { + _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); + } + + return _awaitable; + } + + private SocketAwaitable SendAsync(ReadOnlyMemory memory) + { + // The BufferList getter is much less expensive then the setter. + if (_eventArgs.BufferList != null) + { + _eventArgs.BufferList = null; + } + +#if NETCOREAPP2_1 + _eventArgs.SetBuffer(MemoryMarshal.AsMemory(memory)); +#else + var segment = memory.GetArray(); + + _eventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); +#endif + if (!_socket.SendAsync(_eventArgs)) + { + _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); + } + + return _awaitable; + } + + private List> GetBufferList(ReadOnlySequence buffer) + { + Debug.Assert(!buffer.IsEmpty); + Debug.Assert(!buffer.IsSingleSegment); + + if (_bufferList == null) + { + _bufferList = new List>(); + } + else + { + // Buffers are pooled, so it's OK to root them until the next multi-buffer write. + _bufferList.Clear(); + } + + foreach (var b in buffer) + { + _bufferList.Add(b.GetArray()); + } + + return _bufferList; + } + + public void Dispose() + { + _eventArgs.Dispose(); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketsTrace.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketsTrace.cs new file mode 100644 index 0000000000..e8aa5fa286 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketsTrace.cs @@ -0,0 +1,93 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + public class SocketsTrace : ISocketsTrace + { + // ConnectionRead: Reserved: 3 + + private static readonly Action _connectionPause = + LoggerMessage.Define(LogLevel.Debug, new EventId(4, nameof(ConnectionPause)), @"Connection id ""{ConnectionId}"" paused."); + + private static readonly Action _connectionResume = + LoggerMessage.Define(LogLevel.Debug, new EventId(5, nameof(ConnectionResume)), @"Connection id ""{ConnectionId}"" resumed."); + + private static readonly Action _connectionReadFin = + LoggerMessage.Define(LogLevel.Debug, new EventId(6, nameof(ConnectionReadFin)), @"Connection id ""{ConnectionId}"" received FIN."); + + private static readonly Action _connectionWriteFin = + LoggerMessage.Define(LogLevel.Debug, new EventId(7, nameof(ConnectionWriteFin)), @"Connection id ""{ConnectionId}"" sending FIN."); + + private static readonly Action _connectionError = + LoggerMessage.Define(LogLevel.Information, new EventId(14, nameof(ConnectionError)), @"Connection id ""{ConnectionId}"" communication error."); + + private static readonly Action _connectionReset = + LoggerMessage.Define(LogLevel.Debug, new EventId(19, nameof(ConnectionReset)), @"Connection id ""{ConnectionId}"" reset."); + + private readonly ILogger _logger; + + public SocketsTrace(ILogger logger) + { + _logger = logger; + } + + public void ConnectionRead(string connectionId, int count) + { + // Don't log for now since this could be *too* verbose. + // Reserved: Event ID 3 + } + + public void ConnectionReadFin(string connectionId) + { + _connectionReadFin(_logger, connectionId, null); + } + + public void ConnectionWriteFin(string connectionId) + { + _connectionWriteFin(_logger, connectionId, null); + } + + public void ConnectionWrite(string connectionId, int count) + { + // Don't log for now since this could be *too* verbose. + // Reserved: Event ID 11 + } + + public void ConnectionWriteCallback(string connectionId, int status) + { + // Don't log for now since this could be *too* verbose. + // Reserved: Event ID 12 + } + + public void ConnectionError(string connectionId, Exception ex) + { + _connectionError(_logger, connectionId, ex); + } + + public void ConnectionReset(string connectionId) + { + _connectionReset(_logger, connectionId, null); + } + + public void ConnectionPause(string connectionId) + { + _connectionPause(_logger, connectionId, null); + } + + public void ConnectionResume(string connectionId) + { + _connectionResume(_logger, connectionId, null); + } + + public IDisposable BeginScope(TState state) => _logger.BeginScope(state); + + public bool IsEnabled(LogLevel logLevel) => _logger.IsEnabled(logLevel); + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) + => _logger.Log(logLevel, eventId, state, exception, formatter); + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.csproj b/src/Servers/Kestrel/Transport.Sockets/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.csproj new file mode 100644 index 0000000000..82dde8daa8 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.csproj @@ -0,0 +1,24 @@ + + + + Managed socket transport for the ASP.NET Core Kestrel cross-platform web server. + netstandard2.0;netcoreapp2.1 + true + aspnetcore;kestrel + true + CS1591;$(NoWarn) + + + + + + + + + + + + + + + diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Properties/SocketsStrings.Designer.cs b/src/Servers/Kestrel/Transport.Sockets/src/Properties/SocketsStrings.Designer.cs new file mode 100644 index 0000000000..2d26f8f398 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Properties/SocketsStrings.Designer.cs @@ -0,0 +1,58 @@ +// +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets +{ + using System.Globalization; + using System.Reflection; + using System.Resources; + + internal static class SocketsStrings + { + private static readonly ResourceManager _resourceManager + = new ResourceManager("Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketsStrings", typeof(SocketsStrings).GetTypeInfo().Assembly); + + /// + /// Only ListenType.IPEndPoint is supported by the Socket Transport. https://go.microsoft.com/fwlink/?linkid=874850 + /// + internal static string OnlyIPEndPointsSupported + { + get => GetString("OnlyIPEndPointsSupported"); + } + + /// + /// Only ListenType.IPEndPoint is supported by the Socket Transport. https://go.microsoft.com/fwlink/?linkid=874850 + /// + internal static string FormatOnlyIPEndPointsSupported() + => GetString("OnlyIPEndPointsSupported"); + + /// + /// Transport is already bound. + /// + internal static string TransportAlreadyBound + { + get => GetString("TransportAlreadyBound"); + } + + /// + /// Transport is already bound. + /// + internal static string FormatTransportAlreadyBound() + => GetString("TransportAlreadyBound"); + + private static string GetString(string name, params string[] formatterNames) + { + var value = _resourceManager.GetString(name); + + System.Diagnostics.Debug.Assert(value != null); + + if (formatterNames != null) + { + for (var i = 0; i < formatterNames.Length; i++) + { + value = value.Replace("{" + formatterNames[i] + "}", "{" + i + "}"); + } + } + + return value; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/SocketTransport.cs b/src/Servers/Kestrel/Transport.Sockets/src/SocketTransport.cs new file mode 100644 index 0000000000..1cf2782c15 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/SocketTransport.cs @@ -0,0 +1,230 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Diagnostics; +using System.IO.Pipelines; +using System.Net; +using System.Net.Sockets; +using System.Runtime.ExceptionServices; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets +{ + internal sealed class SocketTransport : ITransport + { + private static readonly PipeScheduler[] ThreadPoolSchedulerArray = new PipeScheduler[] { PipeScheduler.ThreadPool }; + + private readonly MemoryPool _memoryPool = KestrelMemoryPool.Create(); + private readonly IEndPointInformation _endPointInformation; + private readonly IConnectionDispatcher _dispatcher; + private readonly IApplicationLifetime _appLifetime; + private readonly int _numSchedulers; + private readonly PipeScheduler[] _schedulers; + private readonly ISocketsTrace _trace; + private Socket _listenSocket; + private Task _listenTask; + private Exception _listenException; + private volatile bool _unbinding; + + internal SocketTransport( + IEndPointInformation endPointInformation, + IConnectionDispatcher dispatcher, + IApplicationLifetime applicationLifetime, + int ioQueueCount, + ISocketsTrace trace) + { + Debug.Assert(endPointInformation != null); + Debug.Assert(endPointInformation.Type == ListenType.IPEndPoint); + Debug.Assert(dispatcher != null); + Debug.Assert(applicationLifetime != null); + Debug.Assert(trace != null); + + _endPointInformation = endPointInformation; + _dispatcher = dispatcher; + _appLifetime = applicationLifetime; + _trace = trace; + + if (ioQueueCount > 0) + { + _numSchedulers = ioQueueCount; + _schedulers = new IOQueue[_numSchedulers]; + + for (var i = 0; i < _numSchedulers; i++) + { + _schedulers[i] = new IOQueue(); + } + } + else + { + _numSchedulers = ThreadPoolSchedulerArray.Length; + _schedulers = ThreadPoolSchedulerArray; + } + } + + public Task BindAsync() + { + if (_listenSocket != null) + { + throw new InvalidOperationException(SocketsStrings.TransportAlreadyBound); + } + + IPEndPoint endPoint = _endPointInformation.IPEndPoint; + + var listenSocket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + + EnableRebinding(listenSocket); + + // Kestrel expects IPv6Any to bind to both IPv6 and IPv4 + if (endPoint.Address == IPAddress.IPv6Any) + { + listenSocket.DualMode = true; + } + + try + { + listenSocket.Bind(endPoint); + } + catch (SocketException e) when (e.SocketErrorCode == SocketError.AddressAlreadyInUse) + { + throw new AddressInUseException(e.Message, e); + } + + // If requested port was "0", replace with assigned dynamic port. + if (_endPointInformation.IPEndPoint.Port == 0) + { + _endPointInformation.IPEndPoint = (IPEndPoint)listenSocket.LocalEndPoint; + } + + listenSocket.Listen(512); + + _listenSocket = listenSocket; + + _listenTask = Task.Run(() => RunAcceptLoopAsync()); + + return Task.CompletedTask; + } + + public async Task UnbindAsync() + { + if (_listenSocket != null) + { + _unbinding = true; + _listenSocket.Dispose(); + + Debug.Assert(_listenTask != null); + await _listenTask.ConfigureAwait(false); + + _unbinding = false; + _listenSocket = null; + _listenTask = null; + + if (_listenException != null) + { + var exInfo = ExceptionDispatchInfo.Capture(_listenException); + _listenException = null; + exInfo.Throw(); + } + } + } + + public Task StopAsync() + { + _memoryPool.Dispose(); + return Task.CompletedTask; + } + + private async Task RunAcceptLoopAsync() + { + try + { + while (true) + { + for (var schedulerIndex = 0; schedulerIndex < _numSchedulers; schedulerIndex++) + { + try + { + var acceptSocket = await _listenSocket.AcceptAsync(); + acceptSocket.NoDelay = _endPointInformation.NoDelay; + + var connection = new SocketConnection(acceptSocket, _memoryPool, _schedulers[schedulerIndex], _trace); + HandleConnectionAsync(connection); + } + catch (SocketException) when (!_unbinding) + { + _trace.ConnectionReset(connectionId: "(null)"); + } + } + } + } + catch (Exception ex) + { + if (_unbinding) + { + // Means we must be unbinding. Eat the exception. + } + else + { + _trace.LogCritical(ex, $"Unexpected exception in {nameof(SocketTransport)}.{nameof(RunAcceptLoopAsync)}."); + _listenException = ex; + + // Request shutdown so we can rethrow this exception + // in Stop which should be observable. + _appLifetime.StopApplication(); + } + } + } + + private void HandleConnectionAsync(SocketConnection connection) + { + try + { + _dispatcher.OnConnection(connection); + _ = connection.StartAsync(); + } + catch (Exception ex) + { + _trace.LogCritical(ex, $"Unexpected exception in {nameof(SocketTransport)}.{nameof(HandleConnectionAsync)}."); + } + } + + [DllImport("libc", SetLastError = true)] + private static extern int setsockopt(int socket, int level, int option_name, IntPtr option_value, uint option_len); + + private const int SOL_SOCKET_OSX = 0xffff; + private const int SO_REUSEADDR_OSX = 0x0004; + private const int SOL_SOCKET_LINUX = 0x0001; + private const int SO_REUSEADDR_LINUX = 0x0002; + + // Without setting SO_REUSEADDR on macOS and Linux, binding to a recently used endpoint can fail. + // https://github.com/dotnet/corefx/issues/24562 + private unsafe void EnableRebinding(Socket listenSocket) + { + var optionValue = 1; + var setsockoptStatus = 0; + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + setsockoptStatus = setsockopt(listenSocket.Handle.ToInt32(), SOL_SOCKET_LINUX, SO_REUSEADDR_LINUX, + (IntPtr)(&optionValue), sizeof(int)); + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + setsockoptStatus = setsockopt(listenSocket.Handle.ToInt32(), SOL_SOCKET_OSX, SO_REUSEADDR_OSX, + (IntPtr)(&optionValue), sizeof(int)); + } + + if (setsockoptStatus != 0) + { + _trace.LogInformation("Setting SO_REUSEADDR failed with errno '{errno}'.", Marshal.GetLastWin32Error()); + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportFactory.cs b/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportFactory.cs new file mode 100644 index 0000000000..7ddd1d918d --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportFactory.cs @@ -0,0 +1,63 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Logging; +using Microsoft.AspNetCore.Hosting; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets +{ + public sealed class SocketTransportFactory : ITransportFactory + { + private readonly SocketTransportOptions _options; + private readonly IApplicationLifetime _appLifetime; + private readonly SocketsTrace _trace; + + public SocketTransportFactory( + IOptions options, + IApplicationLifetime applicationLifetime, + ILoggerFactory loggerFactory) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + if (applicationLifetime == null) + { + throw new ArgumentNullException(nameof(applicationLifetime)); + } + if (loggerFactory == null) + { + throw new ArgumentNullException(nameof(loggerFactory)); + } + + _options = options.Value; + _appLifetime = applicationLifetime; + var logger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"); + _trace = new SocketsTrace(logger); + } + + public ITransport Create(IEndPointInformation endPointInformation, IConnectionDispatcher dispatcher) + { + if (endPointInformation == null) + { + throw new ArgumentNullException(nameof(endPointInformation)); + } + + if (endPointInformation.Type != ListenType.IPEndPoint) + { + throw new ArgumentException(SocketsStrings.OnlyIPEndPointsSupported, nameof(endPointInformation)); + } + + if (dispatcher == null) + { + throw new ArgumentNullException(nameof(dispatcher)); + } + + return new SocketTransport(endPointInformation, dispatcher, _appLifetime, _options.IOQueueCount, _trace); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportOptions.cs b/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportOptions.cs new file mode 100644 index 0000000000..b6cec0a6d7 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportOptions.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets +{ + public class SocketTransportOptions + { + /// + /// The number of I/O queues used to process requests. Set to 0 to directly schedule I/O to the ThreadPool. + /// + /// + /// Defaults to rounded down and clamped between 1 and 16. + /// + public int IOQueueCount { get; set; } = Math.Min(Environment.ProcessorCount, 16); + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/SocketsStrings.resx b/src/Servers/Kestrel/Transport.Sockets/src/SocketsStrings.resx new file mode 100644 index 0000000000..52b26c66bc --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/SocketsStrings.resx @@ -0,0 +1,126 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Only ListenType.IPEndPoint is supported by the Socket Transport. https://go.microsoft.com/fwlink/?linkid=874850 + + + Transport is already bound. + + \ No newline at end of file diff --git a/src/Servers/Kestrel/Transport.Sockets/src/WebHostBuilderSocketExtensions.cs b/src/Servers/Kestrel/Transport.Sockets/src/WebHostBuilderSocketExtensions.cs new file mode 100644 index 0000000000..95d27e46db --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/WebHostBuilderSocketExtensions.cs @@ -0,0 +1,50 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Hosting +{ + public static class WebHostBuilderSocketExtensions + { + /// + /// Specify Sockets as the transport to be used by Kestrel. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder to configure. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder. + /// + public static IWebHostBuilder UseSockets(this IWebHostBuilder hostBuilder) + { + return hostBuilder.ConfigureServices(services => + { + services.AddSingleton(); + }); + } + + /// + /// Specify Sockets as the transport to be used by Kestrel. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder to configure. + /// + /// + /// A callback to configure Libuv options. + /// + /// + /// The Microsoft.AspNetCore.Hosting.IWebHostBuilder. + /// + public static IWebHostBuilder UseSockets(this IWebHostBuilder hostBuilder, Action configureOptions) + { + return hostBuilder.UseSockets().ConfigureServices(services => + { + services.Configure(configureOptions); + }); + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/baseline.netcore.json b/src/Servers/Kestrel/Transport.Sockets/src/baseline.netcore.json new file mode 100644 index 0000000000..7a73a41bfd --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/baseline.netcore.json @@ -0,0 +1,2 @@ +{ +} \ No newline at end of file diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/AsciiBytesToStringBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/AsciiBytesToStringBenchmark.cs new file mode 100644 index 0000000000..2f9c2423f5 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/AsciiBytesToStringBenchmark.cs @@ -0,0 +1,623 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using BenchmarkDotNet.Attributes; +using System; +using System.Collections.Generic; +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class AsciiBytesToStringBenchmark + { + private const int Iterations = 100; + + private byte[] _asciiBytes; + private string _asciiString = new string('\0', 1024); + + [Params( + BenchmarkTypes.KeepAlive, + BenchmarkTypes.Accept, + BenchmarkTypes.UserAgent, + BenchmarkTypes.Cookie + )] + public BenchmarkTypes Type { get; set; } + + [GlobalSetup] + public void Setup() + { + switch (Type) + { + case BenchmarkTypes.KeepAlive: + _asciiBytes = Encoding.ASCII.GetBytes("keep-alive"); + break; + case BenchmarkTypes.Accept: + _asciiBytes = Encoding.ASCII.GetBytes("text/plain,text/html;q=0.9,application/xhtml+xml;q=0.9,application/xml;q=0.8,*/*;q=0.7"); + break; + case BenchmarkTypes.UserAgent: + _asciiBytes = Encoding.ASCII.GetBytes("Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/54.0.2840.99 Safari/537.36"); + break; + case BenchmarkTypes.Cookie: + _asciiBytes = Encoding.ASCII.GetBytes("prov=20629ccd-8b0f-e8ef-2935-cd26609fc0bc; __qca=P0-1591065732-1479167353442; _ga=GA1.2.1298898376.1479167354; _gat=1; sgt=id=9519gfde_3347_4762_8762_df51458c8ec2; acct=t=why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric&s=why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric"); + break; + } + + Verify(); + } + + [Benchmark(OperationsPerInvoke = Iterations)] + public unsafe string EncodingAsciiGetChars() + { + for (uint i = 0; i < Iterations; i++) + { + fixed (byte* pBytes = &_asciiBytes[0]) + fixed (char* pString = _asciiString) + { + Encoding.ASCII.GetChars(pBytes, _asciiBytes.Length, pString, _asciiBytes.Length); + } + } + + return _asciiString; + } + + [Benchmark(Baseline = true, OperationsPerInvoke = Iterations)] + public unsafe byte[] KestrelBytesToString() + { + for (uint i = 0; i < Iterations; i++) + { + fixed (byte* pBytes = &_asciiBytes[0]) + fixed (char* pString = _asciiString) + { + TryGetAsciiString(pBytes, pString, _asciiBytes.Length); + } + } + + return _asciiBytes; + } + + [Benchmark(OperationsPerInvoke = Iterations)] + public unsafe byte[] AsciiBytesToStringVectorCheck() + { + for (uint i = 0; i < Iterations; i++) + { + fixed (byte* pBytes = &_asciiBytes[0]) + fixed (char* pString = _asciiString) + { + TryGetAsciiStringVectorCheck(pBytes, pString, _asciiBytes.Length); + } + } + + return _asciiBytes; + } + + [Benchmark(OperationsPerInvoke = Iterations)] + public unsafe byte[] AsciiBytesToStringVectorWiden() + { + // Widen Acceleration is post netcoreapp2.0 + for (uint i = 0; i < Iterations; i++) + { + fixed (byte* pBytes = &_asciiBytes[0]) + fixed (char* pString = _asciiString) + { + TryGetAsciiStringVectorWiden(pBytes, pString, _asciiBytes.Length); + } + } + + return _asciiBytes; + } + + [Benchmark(OperationsPerInvoke = Iterations)] + public unsafe byte[] AsciiBytesToStringSpanWiden() + { + // Widen Acceleration is post netcoreapp2.0 + for (uint i = 0; i < Iterations; i++) + { + fixed (char* pString = _asciiString) + { + TryGetAsciiStringWidenSpan(_asciiBytes, new Span(pString, _asciiString.Length)); + } + } + + return _asciiBytes; + } + + public static bool TryGetAsciiStringWidenSpan(ReadOnlySpan input, Span output) + { + // Start as valid + var isValid = true; + + do + { + // If Vector not-accelerated or remaining less than vector size + if (!Vector.IsHardwareAccelerated || input.Length < Vector.Count) + { + if (IntPtr.Size == 8) // Use Intrinsic switch for branch elimination + { + // 64-bit: Loop longs by default + while ((uint)sizeof(long) <= (uint)input.Length) + { + isValid &= CheckBytesInAsciiRange(MemoryMarshal.Cast(input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + output[4] = (char)input[4]; + output[5] = (char)input[5]; + output[6] = (char)input[6]; + output[7] = (char)input[7]; + + input = input.Slice(sizeof(long)); + output = output.Slice(sizeof(long)); + } + if ((uint)sizeof(int) <= (uint)input.Length) + { + isValid &= CheckBytesInAsciiRange(MemoryMarshal.Cast(input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + + input = input.Slice(sizeof(int)); + output = output.Slice(sizeof(int)); + } + } + else + { + // 32-bit: Loop ints by default + while ((uint)sizeof(int) <= (uint)input.Length) + { + isValid &= CheckBytesInAsciiRange(MemoryMarshal.Cast(input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + + input = input.Slice(sizeof(int)); + output = output.Slice(sizeof(int)); + } + } + if ((uint)sizeof(short) <= (uint)input.Length) + { + isValid &= CheckBytesInAsciiRange(MemoryMarshal.Cast(input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + + input = input.Slice(sizeof(short)); + output = output.Slice(sizeof(short)); + } + if ((uint)sizeof(byte) <= (uint)input.Length) + { + isValid &= CheckBytesInAsciiRange((sbyte)input[0]); + output[0] = (char)input[0]; + } + + return isValid; + } + + // do/while as entry condition already checked + do + { + var vector = MemoryMarshal.Cast>(input)[0]; + isValid &= CheckBytesInAsciiRange(vector); + Vector.Widen( + vector, + out MemoryMarshal.Cast>(output)[0], + out MemoryMarshal.Cast>(output)[1]); + + input = input.Slice(Vector.Count); + output = output.Slice(Vector.Count); + } while (input.Length >= Vector.Count); + + // Vector path done, loop back to do non-Vector + // If is a exact multiple of vector size, bail now + } while (input.Length > 0); + + return isValid; + } + + public static unsafe bool TryGetAsciiStringVectorWiden(byte* input, char* output, int count) + { + // Calculate end position + var end = input + count; + // Start as valid + var isValid = true; + + do + { + // If Vector not-accelerated or remaining less than vector size + if (!Vector.IsHardwareAccelerated || input > end - Vector.Count) + { + if (IntPtr.Size == 8) // Use Intrinsic switch for branch elimination + { + // 64-bit: Loop longs by default + while (input <= end - sizeof(long)) + { + isValid &= CheckBytesInAsciiRange(((long*)input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + output[4] = (char)input[4]; + output[5] = (char)input[5]; + output[6] = (char)input[6]; + output[7] = (char)input[7]; + + input += sizeof(long); + output += sizeof(long); + } + if (input <= end - sizeof(int)) + { + isValid &= CheckBytesInAsciiRange(((int*)input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + + input += sizeof(int); + output += sizeof(int); + } + } + else + { + // 32-bit: Loop ints by default + while (input <= end - sizeof(int)) + { + isValid &= CheckBytesInAsciiRange(((int*)input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + + input += sizeof(int); + output += sizeof(int); + } + } + if (input <= end - sizeof(short)) + { + isValid &= CheckBytesInAsciiRange(((short*)input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + + input += sizeof(short); + output += sizeof(short); + } + if (input < end) + { + isValid &= CheckBytesInAsciiRange(((sbyte*)input)[0]); + output[0] = (char)input[0]; + } + + return isValid; + } + + // do/while as entry condition already checked + do + { + var vector = Unsafe.AsRef>(input); + isValid &= CheckBytesInAsciiRange(vector); + Vector.Widen( + vector, + out Unsafe.AsRef>(output), + out Unsafe.AsRef>(output + Vector.Count)); + + input += Vector.Count; + output += Vector.Count; + } while (input <= end - Vector.Count); + + // Vector path done, loop back to do non-Vector + // If is a exact multiple of vector size, bail now + } while (input < end); + + return isValid; + } + + public static unsafe bool TryGetAsciiStringVectorCheck(byte* input, char* output, int count) + { + // Calculate end position + var end = input + count; + // Start as valid + var isValid = true; + do + { + // If Vector not-accelerated or remaining less than vector size + if (!Vector.IsHardwareAccelerated || input > end - Vector.Count) + { + if (IntPtr.Size == 8) // Use Intrinsic switch for branch elimination + { + // 64-bit: Loop longs by default + while (input <= end - sizeof(long)) + { + isValid &= CheckBytesInAsciiRange(((long*)input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + output[4] = (char)input[4]; + output[5] = (char)input[5]; + output[6] = (char)input[6]; + output[7] = (char)input[7]; + + input += sizeof(long); + output += sizeof(long); + } + if (input <= end - sizeof(int)) + { + isValid &= CheckBytesInAsciiRange(((int*)input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + + input += sizeof(int); + output += sizeof(int); + } + } + else + { + // 32-bit: Loop ints by default + while (input <= end - sizeof(int)) + { + isValid &= CheckBytesInAsciiRange(((int*)input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + + input += sizeof(int); + output += sizeof(int); + } + } + if (input <= end - sizeof(short)) + { + isValid &= CheckBytesInAsciiRange(((short*)input)[0]); + + output[0] = (char)input[0]; + output[1] = (char)input[1]; + + input += sizeof(short); + output += sizeof(short); + } + if (input < end) + { + isValid &= CheckBytesInAsciiRange(((sbyte*)input)[0]); + output[0] = (char)input[0]; + } + + return isValid; + } + + // do/while as entry condition already checked + do + { + isValid &= CheckBytesInAsciiRange(Unsafe.AsRef>(input)); + + // Vector.Widen is only netcoreapp2.1+ so let's do this manually + var i = 0; + do + { + // Vectors are min 16 byte, so lets do 16 byte loops + i += 16; + // Unrolled byte-wise widen + output[0] = (char)input[0]; + output[1] = (char)input[1]; + output[2] = (char)input[2]; + output[3] = (char)input[3]; + output[4] = (char)input[4]; + output[5] = (char)input[5]; + output[6] = (char)input[6]; + output[7] = (char)input[7]; + output[8] = (char)input[8]; + output[9] = (char)input[9]; + output[10] = (char)input[10]; + output[11] = (char)input[11]; + output[12] = (char)input[12]; + output[13] = (char)input[13]; + output[14] = (char)input[14]; + output[15] = (char)input[15]; + + input += 16; + output += 16; + } while (i < Vector.Count); + } while (input <= end - Vector.Count); + + // Vector path done, loop back to do non-Vector + // If is a exact multiple of vector size, bail now + } while (input < end); + + return isValid; + } + + public static unsafe bool TryGetAsciiString(byte* input, char* output, int count) + { + var i = 0; + sbyte* signedInput = (sbyte*)input; + + bool isValid = true; + while (i < count - 11) + { + isValid = isValid && *signedInput > 0 && *(signedInput + 1) > 0 && *(signedInput + 2) > 0 && + *(signedInput + 3) > 0 && *(signedInput + 4) > 0 && *(signedInput + 5) > 0 && *(signedInput + 6) > 0 && + *(signedInput + 7) > 0 && *(signedInput + 8) > 0 && *(signedInput + 9) > 0 && *(signedInput + 10) > 0 && + *(signedInput + 11) > 0; + + i += 12; + *(output) = (char)*(signedInput); + *(output + 1) = (char)*(signedInput + 1); + *(output + 2) = (char)*(signedInput + 2); + *(output + 3) = (char)*(signedInput + 3); + *(output + 4) = (char)*(signedInput + 4); + *(output + 5) = (char)*(signedInput + 5); + *(output + 6) = (char)*(signedInput + 6); + *(output + 7) = (char)*(signedInput + 7); + *(output + 8) = (char)*(signedInput + 8); + *(output + 9) = (char)*(signedInput + 9); + *(output + 10) = (char)*(signedInput + 10); + *(output + 11) = (char)*(signedInput + 11); + output += 12; + signedInput += 12; + } + if (i < count - 5) + { + isValid = isValid && *signedInput > 0 && *(signedInput + 1) > 0 && *(signedInput + 2) > 0 && + *(signedInput + 3) > 0 && *(signedInput + 4) > 0 && *(signedInput + 5) > 0; + + i += 6; + *(output) = (char)*(signedInput); + *(output + 1) = (char)*(signedInput + 1); + *(output + 2) = (char)*(signedInput + 2); + *(output + 3) = (char)*(signedInput + 3); + *(output + 4) = (char)*(signedInput + 4); + *(output + 5) = (char)*(signedInput + 5); + output += 6; + signedInput += 6; + } + if (i < count - 3) + { + isValid = isValid && *signedInput > 0 && *(signedInput + 1) > 0 && *(signedInput + 2) > 0 && + *(signedInput + 3) > 0; + + i += 4; + *(output) = (char)*(signedInput); + *(output + 1) = (char)*(signedInput + 1); + *(output + 2) = (char)*(signedInput + 2); + *(output + 3) = (char)*(signedInput + 3); + output += 4; + signedInput += 4; + } + + while (i < count) + { + isValid = isValid && *signedInput > 0; + + i++; + *output = (char)*signedInput; + output++; + signedInput++; + } + + return isValid; + } + + private static bool CheckBytesInAsciiRange(Vector check) + { + // Vectorized byte range check, signed byte > 0 for 1-127 + return Vector.GreaterThanAll(check, Vector.Zero); + } + + // Validate: bytes != 0 && bytes <= 127 + // Subtract 1 from all bytes to move 0 to high bits + // bitwise or with self to catch all > 127 bytes + // mask off high bits and check if 0 + + [MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push + private static bool CheckBytesInAsciiRange(long check) + { + const long HighBits = unchecked((long)0x8080808080808080L); + return (((check - 0x0101010101010101L) | check) & HighBits) == 0; + } + + private static bool CheckBytesInAsciiRange(int check) + { + const int HighBits = unchecked((int)0x80808080); + return (((check - 0x01010101) | check) & HighBits) == 0; + } + + private static bool CheckBytesInAsciiRange(short check) + { + const short HighBits = unchecked((short)0x8080); + return (((short)(check - 0x0101) | check) & HighBits) == 0; + } + + private static bool CheckBytesInAsciiRange(sbyte check) + => check > 0; + + private void Verify() + { + var verification = EncodingAsciiGetChars().Substring(0, _asciiBytes.Length); + + BlankString('\0'); + EncodingAsciiGetChars(); + VerifyString(verification, '\0'); + BlankString(' '); + EncodingAsciiGetChars(); + VerifyString(verification, ' '); + + BlankString('\0'); + KestrelBytesToString(); + VerifyString(verification, '\0'); + BlankString(' '); + KestrelBytesToString(); + VerifyString(verification, ' '); + + BlankString('\0'); + AsciiBytesToStringVectorCheck(); + VerifyString(verification, '\0'); + BlankString(' '); + AsciiBytesToStringVectorCheck(); + VerifyString(verification, ' '); + + BlankString('\0'); + AsciiBytesToStringVectorWiden(); + VerifyString(verification, '\0'); + BlankString(' '); + AsciiBytesToStringVectorWiden(); + VerifyString(verification, ' '); + + BlankString('\0'); + AsciiBytesToStringSpanWiden(); + VerifyString(verification, '\0'); + BlankString(' '); + AsciiBytesToStringSpanWiden(); + VerifyString(verification, ' '); + } + + private unsafe void BlankString(char ch) + { + fixed (char* pString = _asciiString) + { + for (var i = 0; i < _asciiString.Length; i++) + { + *(pString + i) = ch; + } + } + } + + private unsafe void VerifyString(string verification, char ch) + { + fixed (char* pString = _asciiString) + { + var i = 0; + for (; i < verification.Length; i++) + { + if (*(pString + i) != verification[i]) throw new Exception($"Verify failed, saw {(int)*(pString + i)} expected {(int)verification[i]} at position {i}"); + } + for (; i < _asciiString.Length; i++) + { + if (*(pString + i) != ch) throw new Exception($"Verify failed, saw {(int)*(pString + i)} expected {(int)ch} at position {i}"); ; + } + } + } + + public enum BenchmarkTypes + { + KeepAlive, + Accept, + UserAgent, + Cookie, + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/AssemblyInfo.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/AssemblyInfo.cs new file mode 100644 index 0000000000..32248e0d1b --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/AssemblyInfo.cs @@ -0,0 +1 @@ +[assembly: BenchmarkDotNet.Attributes.AspNetCoreBenchmark] diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/DotSegmentRemovalBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/DotSegmentRemovalBenchmark.cs new file mode 100644 index 0000000000..5f943d97de --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/DotSegmentRemovalBenchmark.cs @@ -0,0 +1,58 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Text; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class DotSegmentRemovalBenchmark + { + // Immutable + private const string _noDotSegments = "/long/request/target/for/benchmarking/what/else/can/we/put/here"; + private const string _singleDotSegments = "/long/./request/./target/./for/./benchmarking/./what/./else/./can/./we/./put/./here"; + private const string _doubleDotSegments = "/long/../request/../target/../for/../benchmarking/../what/../else/../can/../we/../put/../here"; + + private readonly byte[] _noDotSegmentsAscii = Encoding.ASCII.GetBytes(_noDotSegments); + private readonly byte[] _singleDotSegmentsAscii = Encoding.ASCII.GetBytes(_singleDotSegments); + private readonly byte[] _doubleDotSegmentsAscii = Encoding.ASCII.GetBytes(_doubleDotSegments); + + private readonly byte[] _noDotSegmentsBytes = new byte[_noDotSegments.Length]; + private readonly byte[] _singleDotSegmentsBytes = new byte[_singleDotSegments.Length]; + private readonly byte[] _doubleDotSegmentsBytes = new byte[_doubleDotSegments.Length]; + + [Benchmark(Baseline = true)] + public unsafe int NoDotSegments() + { + _noDotSegmentsAscii.CopyTo(_noDotSegmentsBytes, 0); + + fixed (byte* start = _noDotSegmentsBytes) + { + return PathNormalizer.RemoveDotSegments(start, start + _noDotSegments.Length); + } + } + + [Benchmark] + public unsafe int SingleDotSegments() + { + _singleDotSegmentsAscii.CopyTo(_singleDotSegmentsBytes, 0); + + fixed (byte* start = _singleDotSegmentsBytes) + { + return PathNormalizer.RemoveDotSegments(start, start + _singleDotSegments.Length); + } + } + + [Benchmark] + public unsafe int DoubleDotSegments() + { + _doubleDotSegmentsAscii.CopyTo(_doubleDotSegmentsBytes, 0); + + fixed (byte* start = _doubleDotSegmentsBytes) + { + return PathNormalizer.RemoveDotSegments(start, start + _doubleDotSegments.Length); + } + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/ErrorUtilities.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/ErrorUtilities.cs new file mode 100644 index 0000000000..32ae14d571 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/ErrorUtilities.cs @@ -0,0 +1,17 @@ +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public static class ErrorUtilities + { + public static void ThrowInvalidRequestLine() + { + throw new InvalidOperationException("Invalid request line"); + } + + public static void ThrowInvalidRequestHeaders() + { + throw new InvalidOperationException("Invalid request headers"); + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/Http1ConnectionBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/Http1ConnectionBenchmark.cs new file mode 100644 index 0000000000..b90221b4a7 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/Http1ConnectionBenchmark.cs @@ -0,0 +1,115 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Performance.Mocks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class Http1ConnectionBenchmark + { + private const int InnerLoopCount = 512; + + private readonly HttpParser _parser = new HttpParser(); + + private ReadOnlySequence _buffer; + + public Http1Connection Connection { get; set; } + + [GlobalSetup] + public void Setup() + { + var memoryPool = KestrelMemoryPool.Create(); + var options = new PipeOptions(memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + + var serviceContext = new ServiceContext + { + ServerOptions = new KestrelServerOptions(), + HttpParser = NullParser.Instance + }; + + var http1Connection = new Http1Connection(context: new Http1ConnectionContext + { + ServiceContext = serviceContext, + ConnectionFeatures = new FeatureCollection(), + MemoryPool = memoryPool, + TimeoutControl = new MockTimeoutControl(), + Application = pair.Application, + Transport = pair.Transport + }); + + http1Connection.Reset(); + + Connection = http1Connection; + } + + [Benchmark(Baseline = true, OperationsPerInvoke = RequestParsingData.InnerLoopCount)] + public void PlaintextTechEmpower() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.PlaintextTechEmpowerRequest); + ParseData(); + } + } + + [Benchmark(OperationsPerInvoke = RequestParsingData.InnerLoopCount)] + public void LiveAspNet() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.LiveaspnetRequest); + ParseData(); + } + } + + private void InsertData(byte[] data) + { + _buffer = new ReadOnlySequence(data); + } + + private void ParseData() + { + if (!_parser.ParseRequestLine(new Adapter(this), _buffer, out var consumed, out var examined)) + { + ErrorUtilities.ThrowInvalidRequestHeaders(); + } + + _buffer = _buffer.Slice(consumed, _buffer.End); + + if (!_parser.ParseHeaders(new Adapter(this), _buffer, out consumed, out examined, out var consumedBytes)) + { + ErrorUtilities.ThrowInvalidRequestHeaders(); + } + + Connection.EnsureHostHeaderExists(); + + Connection.Reset(); + } + + private struct Adapter : IHttpRequestLineHandler, IHttpHeadersHandler + { + public Http1ConnectionBenchmark RequestHandler; + + public Adapter(Http1ConnectionBenchmark requestHandler) + { + RequestHandler = requestHandler; + } + + public void OnHeader(Span name, Span value) + => RequestHandler.Connection.OnHeader(name, value); + + public void OnStartLine(HttpMethod method, HttpVersion version, Span target, Span path, Span query, Span customMethod, bool pathEncoded) + => RequestHandler.Connection.OnStartLine(method, version, target, path, query, customMethod, pathEncoded); + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/Http1ConnectionParsingOverheadBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/Http1ConnectionParsingOverheadBenchmark.cs new file mode 100644 index 0000000000..3046c5f1e9 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/Http1ConnectionParsingOverheadBenchmark.cs @@ -0,0 +1,113 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.IO.Pipelines; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Performance.Mocks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class Http1ConnectionParsingOverheadBenchmark + { + private const int InnerLoopCount = 512; + + public ReadOnlySequence _buffer; + public Http1Connection _http1Connection; + + [IterationSetup] + public void Setup() + { + var memoryPool = KestrelMemoryPool.Create(); + var options = new PipeOptions(memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + + var serviceContext = new ServiceContext + { + ServerOptions = new KestrelServerOptions(), + HttpParser = NullParser.Instance + }; + + var http1Connection = new Http1Connection(new Http1ConnectionContext + { + ServiceContext = serviceContext, + ConnectionFeatures = new FeatureCollection(), + MemoryPool = memoryPool, + TimeoutControl = new MockTimeoutControl(), + Application = pair.Application, + Transport = pair.Transport + }); + + http1Connection.Reset(); + + _http1Connection = http1Connection; + } + + [Benchmark(Baseline = true, OperationsPerInvoke = InnerLoopCount)] + public void Http1ConnectionOverheadTotal() + { + for (var i = 0; i < InnerLoopCount; i++) + { + ParseRequest(); + } + } + + [Benchmark(OperationsPerInvoke = InnerLoopCount)] + public void Http1ConnectionOverheadRequestLine() + { + for (var i = 0; i < InnerLoopCount; i++) + { + ParseRequestLine(); + } + } + + [Benchmark(OperationsPerInvoke = InnerLoopCount)] + public void Http1ConnectionOverheadRequestHeaders() + { + for (var i = 0; i < InnerLoopCount; i++) + { + ParseRequestHeaders(); + } + } + + private void ParseRequest() + { + _http1Connection.Reset(); + + if (!_http1Connection.TakeStartLine(_buffer, out var consumed, out var examined)) + { + ErrorUtilities.ThrowInvalidRequestLine(); + } + + if (!_http1Connection.TakeMessageHeaders(_buffer, out consumed, out examined)) + { + ErrorUtilities.ThrowInvalidRequestHeaders(); + } + } + + private void ParseRequestLine() + { + _http1Connection.Reset(); + + if (!_http1Connection.TakeStartLine(_buffer, out var consumed, out var examined)) + { + ErrorUtilities.ThrowInvalidRequestLine(); + } + } + + private void ParseRequestHeaders() + { + _http1Connection.Reset(); + + if (!_http1Connection.TakeMessageHeaders(_buffer, out var consumed, out var examined)) + { + ErrorUtilities.ThrowInvalidRequestHeaders(); + } + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/Http1WritingBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/Http1WritingBenchmark.cs new file mode 100644 index 0000000000..891d94ab0a --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/Http1WritingBenchmark.cs @@ -0,0 +1,149 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class Http1WritingBenchmark + { + // Standard completed task + private static readonly Func _syncTaskFunc = (obj) => Task.CompletedTask; + // Non-standard completed task + private static readonly Task _psuedoAsyncTask = Task.FromResult(27); + private static readonly Func _psuedoAsyncTaskFunc = (obj) => _psuedoAsyncTask; + + private TestHttp1Connection _http1Connection; + private DuplexPipe.DuplexPipePair _pair; + private MemoryPool _memoryPool; + + private readonly byte[] _writeData = Encoding.ASCII.GetBytes("Hello, World!"); + + [GlobalSetup] + public void GlobalSetup() + { + _memoryPool = KestrelMemoryPool.Create(); + _http1Connection = MakeHttp1Connection(); + } + + [Params(true, false)] + public bool WithHeaders { get; set; } + + [Params(true, false)] + public bool Chunked { get; set; } + + [Params(Startup.None, Startup.Sync, Startup.Async)] + public Startup OnStarting { get; set; } + + [IterationSetup] + public void Setup() + { + _http1Connection.Reset(); + if (Chunked) + { + _http1Connection.RequestHeaders.Add("Transfer-Encoding", "chunked"); + } + else + { + _http1Connection.RequestHeaders.ContentLength = _writeData.Length; + } + + if (!WithHeaders) + { + _http1Connection.FlushAsync().GetAwaiter().GetResult(); + } + + ResetState(); + } + + private void ResetState() + { + if (WithHeaders) + { + _http1Connection.ResetState(); + + switch (OnStarting) + { + case Startup.Sync: + _http1Connection.OnStarting(_syncTaskFunc, null); + break; + case Startup.Async: + _http1Connection.OnStarting(_psuedoAsyncTaskFunc, null); + break; + } + } + } + + [Benchmark] + public Task WriteAsync() + { + ResetState(); + + return _http1Connection.ResponseBody.WriteAsync(_writeData, 0, _writeData.Length, default(CancellationToken)); + } + + private TestHttp1Connection MakeHttp1Connection() + { + var options = new PipeOptions(_memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + _pair = pair; + + var serviceContext = new ServiceContext + { + DateHeaderValueManager = new DateHeaderValueManager(), + ServerOptions = new KestrelServerOptions(), + Log = new MockTrace(), + HttpParser = new HttpParser() + }; + + var http1Connection = new TestHttp1Connection(new Http1ConnectionContext + { + ServiceContext = serviceContext, + ConnectionFeatures = new FeatureCollection(), + MemoryPool = _memoryPool, + Application = pair.Application, + Transport = pair.Transport + }); + + http1Connection.Reset(); + http1Connection.InitializeStreams(MessageBody.ZeroContentLengthKeepAlive); + + return http1Connection; + } + + [IterationCleanup] + public void Cleanup() + { + var reader = _pair.Application.Input; + if (reader.TryRead(out var readResult)) + { + reader.AdvanceTo(readResult.Buffer.End); + } + } + + public enum Startup + { + None, + Sync, + Async + } + + [GlobalCleanup] + public void Dispose() + { + _memoryPool?.Dispose(); + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/HttpParserBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/HttpParserBenchmark.cs new file mode 100644 index 0000000000..bb14ffb537 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/HttpParserBenchmark.cs @@ -0,0 +1,91 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class HttpParserBenchmark : IHttpRequestLineHandler, IHttpHeadersHandler + { + private readonly HttpParser _parser = new HttpParser(); + + private ReadOnlySequence _buffer; + + [Benchmark(Baseline = true, OperationsPerInvoke = RequestParsingData.InnerLoopCount)] + public void PlaintextTechEmpower() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.PlaintextTechEmpowerRequest); + ParseData(); + } + } + + [Benchmark(OperationsPerInvoke = RequestParsingData.InnerLoopCount)] + public void LiveAspNet() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.LiveaspnetRequest); + ParseData(); + } + } + + [Benchmark(OperationsPerInvoke = RequestParsingData.InnerLoopCount)] + public void Unicode() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.UnicodeRequest); + ParseData(); + } + } + + private void InsertData(byte[] data) + { + _buffer = new ReadOnlySequence(data); + } + + private void ParseData() + { + if (!_parser.ParseRequestLine(new Adapter(this), _buffer, out var consumed, out var examined)) + { + ErrorUtilities.ThrowInvalidRequestHeaders(); + } + + _buffer = _buffer.Slice(consumed, _buffer.End); + + if (!_parser.ParseHeaders(new Adapter(this), _buffer, out consumed, out examined, out var consumedBytes)) + { + ErrorUtilities.ThrowInvalidRequestHeaders(); + } + } + + public void OnStartLine(HttpMethod method, HttpVersion version, Span target, Span path, Span query, Span customMethod, bool pathEncoded) + { + } + + public void OnHeader(Span name, Span value) + { + } + + private struct Adapter : IHttpRequestLineHandler, IHttpHeadersHandler + { + public HttpParserBenchmark RequestHandler; + + public Adapter(HttpParserBenchmark requestHandler) + { + RequestHandler = requestHandler; + } + + public void OnHeader(Span name, Span value) + => RequestHandler.OnHeader(name, value); + + public void OnStartLine(HttpMethod method, HttpVersion version, Span target, Span path, Span query, Span customMethod, bool pathEncoded) + => RequestHandler.OnStartLine(method, version, target, path, query, customMethod, pathEncoded); + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/HttpProtocolFeatureCollection.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/HttpProtocolFeatureCollection.cs new file mode 100644 index 0000000000..24b2c6c4b6 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/HttpProtocolFeatureCollection.cs @@ -0,0 +1,125 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class HttpProtocolFeatureCollection + { + private readonly IFeatureCollection _collection; + + [Benchmark] + [MethodImpl(MethodImplOptions.NoInlining)] + public IHttpRequestFeature GetViaTypeOf_First() + { + return (IHttpRequestFeature)_collection[typeof(IHttpRequestFeature)]; + } + + [Benchmark] + [MethodImpl(MethodImplOptions.NoInlining)] + public IHttpRequestFeature GetViaGeneric_First() + { + return _collection.Get(); + } + + [Benchmark] + [MethodImpl(MethodImplOptions.NoInlining)] + public IHttpSendFileFeature GetViaTypeOf_Last() + { + return (IHttpSendFileFeature)_collection[typeof(IHttpSendFileFeature)]; + } + + [Benchmark] + [MethodImpl(MethodImplOptions.NoInlining)] + public IHttpSendFileFeature GetViaGeneric_Last() + { + return _collection.Get(); + } + + [Benchmark] + [MethodImpl(MethodImplOptions.NoInlining)] + public object GetViaTypeOf_Custom() + { + return (IHttpCustomFeature)_collection[typeof(IHttpCustomFeature)]; + } + + [Benchmark] + [MethodImpl(MethodImplOptions.NoInlining)] + public object GetViaGeneric_Custom() + { + return _collection.Get(); + } + + + [Benchmark] + [MethodImpl(MethodImplOptions.NoInlining)] + public object GetViaTypeOf_NotFound() + { + return (IHttpNotFoundFeature)_collection[typeof(IHttpNotFoundFeature)]; + } + + [Benchmark] + [MethodImpl(MethodImplOptions.NoInlining)] + public object GetViaGeneric_NotFound() + { + return _collection.Get(); + } + + public HttpProtocolFeatureCollection() + { + var memoryPool = KestrelMemoryPool.Create(); + var options = new PipeOptions(memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + + var serviceContext = new ServiceContext + { + DateHeaderValueManager = new DateHeaderValueManager(), + ServerOptions = new KestrelServerOptions(), + Log = new MockTrace(), + HttpParser = new HttpParser() + }; + + var http1Connection = new Http1Connection(new Http1ConnectionContext + { + ServiceContext = serviceContext, + ConnectionFeatures = new FeatureCollection(), + MemoryPool = memoryPool, + Application = pair.Application, + Transport = pair.Transport + }); + + http1Connection.Reset(); + + _collection = http1Connection; + } + + private class SendFileFeature : IHttpSendFileFeature + { + public Task SendFileAsync(string path, long offset, long? count, CancellationToken cancellation) + { + throw new NotImplementedException(); + } + } + + private interface IHttpCustomFeature + { + } + + private interface IHttpNotFoundFeature + { + } + } + +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/InMemoryTransportBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/InMemoryTransportBenchmark.cs new file mode 100644 index 0000000000..0e77ddd9c5 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/InMemoryTransportBenchmark.cs @@ -0,0 +1,243 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.DependencyInjection; +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class InMemoryTransportBenchmark + { + private const string _plaintextExpectedResponse = + "HTTP/1.1 200 OK\r\n" + + "Date: Fri, 02 Mar 2018 18:37:05 GMT\r\n" + + "Content-Type: text/plain\r\n" + + "Server: Kestrel\r\n" + + "Content-Length: 13\r\n" + + "\r\n" + + "Hello, World!"; + + private static readonly string _plaintextPipelinedExpectedResponse = + string.Concat(Enumerable.Repeat(_plaintextExpectedResponse, RequestParsingData.Pipelining)); + + private IWebHost _host; + private InMemoryConnection _connection; + + [GlobalSetup(Target = nameof(Plaintext) + "," + nameof(PlaintextPipelined))] + public void GlobalSetupPlaintext() + { + var transportFactory = new InMemoryTransportFactory(connectionsPerEndPoint: 1); + + _host = new WebHostBuilder() + // Prevent VS from attaching to hosting startup which could impact results + .UseSetting("preventHostingStartup", "true") + .UseKestrel() + // Bind to a single non-HTTPS endpoint + .UseUrls("http://127.0.0.1:5000") + .ConfigureServices(services => services.AddSingleton(transportFactory)) + .Configure(app => app.UseMiddleware()) + .Build(); + + _host.Start(); + + // Ensure there is a single endpoint and single connection + _connection = transportFactory.Connections.Values.Single().Single(); + + ValidateResponseAsync(RequestParsingData.PlaintextTechEmpowerRequest, _plaintextExpectedResponse).Wait(); + ValidateResponseAsync(RequestParsingData.PlaintextTechEmpowerPipelinedRequests, _plaintextPipelinedExpectedResponse).Wait(); + } + + private async Task ValidateResponseAsync(byte[] request, string expectedResponse) + { + await _connection.SendRequestAsync(request); + var response = Encoding.ASCII.GetString(await _connection.GetResponseAsync(expectedResponse.Length)); + + // Exclude date header since the value changes on every request + var expectedResponseLines = expectedResponse.Split("\r\n").Where(s => !s.StartsWith("Date:")); + var responseLines = response.Split("\r\n").Where(s => !s.StartsWith("Date:")); + + if (!Enumerable.SequenceEqual(expectedResponseLines, responseLines)) + { + throw new InvalidOperationException(string.Join(Environment.NewLine, + "Invalid response", "Expected:", expectedResponse, "Actual:", response)); + } + } + + [GlobalCleanup] + public void GlobalCleanup() + { + _host.Dispose(); + } + + [Benchmark] + public async Task Plaintext() + { + await _connection.SendRequestAsync(RequestParsingData.PlaintextTechEmpowerRequest); + await _connection.ReadResponseAsync(_plaintextExpectedResponse.Length); + } + + [Benchmark(OperationsPerInvoke = RequestParsingData.Pipelining)] + public async Task PlaintextPipelined() + { + await _connection.SendRequestAsync(RequestParsingData.PlaintextTechEmpowerPipelinedRequests); + await _connection.ReadResponseAsync(_plaintextPipelinedExpectedResponse.Length); + } + + public class InMemoryTransportFactory : ITransportFactory + { + private readonly int _connectionsPerEndPoint; + + private readonly Dictionary> _connections = + new Dictionary>(); + + public IReadOnlyDictionary> Connections => _connections; + + public InMemoryTransportFactory(int connectionsPerEndPoint) + { + _connectionsPerEndPoint = connectionsPerEndPoint; + } + + public ITransport Create(IEndPointInformation endPointInformation, IConnectionDispatcher handler) + { + var connections = new InMemoryConnection[_connectionsPerEndPoint]; + for (var i = 0; i < _connectionsPerEndPoint; i++) + { + connections[i] = new InMemoryConnection(); + } + + _connections.Add(endPointInformation, connections); + + return new InMemoryTransport(handler, connections); + } + } + + public class InMemoryTransport : ITransport + { + private readonly IConnectionDispatcher _dispatcher; + private readonly IReadOnlyList _connections; + + public InMemoryTransport(IConnectionDispatcher dispatcher, IReadOnlyList connections) + { + _dispatcher = dispatcher; + _connections = connections; + } + + public Task BindAsync() + { + foreach (var connection in _connections) + { + _dispatcher.OnConnection(connection); + } + + return Task.CompletedTask; + } + + public Task StopAsync() + { + return Task.CompletedTask; + } + + public Task UnbindAsync() + { + return Task.CompletedTask; + } + } + + public class InMemoryConnection : TransportConnection + { + public ValueTask SendRequestAsync(byte[] request) + { + return Input.WriteAsync(request); + } + + // Reads response as efficiently as possible (similar to LibuvTransport), but doesn't return anything + public async Task ReadResponseAsync(int length) + { + while (length > 0) + { + var result = await Output.ReadAsync(); + var buffer = result.Buffer; + length -= (int)buffer.Length; + Output.AdvanceTo(buffer.End); + } + + if (length < 0) + { + throw new InvalidOperationException($"Invalid response, length={length}"); + } + } + + // Returns response so it can be validated, but is slower and allocates more than ReadResponseAsync() + public async Task GetResponseAsync(int length) + { + while (true) + { + var result = await Output.ReadAsync(); + var buffer = result.Buffer; + var consumed = buffer.Start; + var examined = buffer.End; + + try + { + if (buffer.Length >= length) + { + var response = buffer.Slice(0, length); + consumed = response.End; + examined = response.End; + return response.ToArray(); + } + } + finally + { + Output.AdvanceTo(consumed, examined); + } + } + } + } + + // Copied from https://github.com/aspnet/benchmarks/blob/dev/src/Benchmarks/Middleware/PlaintextMiddleware.cs + public class PlaintextMiddleware + { + private static readonly PathString _path = new PathString("/plaintext"); + private static readonly byte[] _helloWorldPayload = Encoding.UTF8.GetBytes("Hello, World!"); + + private readonly RequestDelegate _next; + + public PlaintextMiddleware(RequestDelegate next) + { + _next = next; + } + + public Task Invoke(HttpContext httpContext) + { + if (httpContext.Request.Path.StartsWithSegments(_path, StringComparison.Ordinal)) + { + return WriteResponse(httpContext.Response); + } + + return _next(httpContext); + } + + public static Task WriteResponse(HttpResponse response) + { + var payloadLength = _helloWorldPayload.Length; + response.StatusCode = 200; + response.ContentType = "text/plain"; + response.ContentLength = payloadLength; + return response.Body.WriteAsync(_helloWorldPayload, 0, payloadLength); + } + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/KnownStringsBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/KnownStringsBenchmark.cs new file mode 100644 index 0000000000..69bf2d5adb --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/KnownStringsBenchmark.cs @@ -0,0 +1,159 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Text; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class KnownStringsBenchmark + { + static byte[] _methodConnect = Encoding.ASCII.GetBytes("CONNECT "); + static byte[] _methodDelete = Encoding.ASCII.GetBytes("DELETE \0"); + static byte[] _methodGet = Encoding.ASCII.GetBytes("GET "); + static byte[] _methodHead = Encoding.ASCII.GetBytes("HEAD \0\0\0"); + static byte[] _methodPatch = Encoding.ASCII.GetBytes("PATCH \0\0"); + static byte[] _methodPost = Encoding.ASCII.GetBytes("POST \0\0\0"); + static byte[] _methodPut = Encoding.ASCII.GetBytes("PUT \0\0\0\0"); + static byte[] _methodOptions = Encoding.ASCII.GetBytes("OPTIONS "); + static byte[] _methodTrace = Encoding.ASCII.GetBytes("TRACE \0\0"); + + static byte[] _version = Encoding.UTF8.GetBytes("HTTP/1.1\r\n"); + const int loops = 1000; + + [Benchmark(OperationsPerInvoke = loops * 10)] + public int GetKnownMethod_GET() + { + Span data = _methodGet; + + return GetKnownMethod(data); + } + + [Benchmark(OperationsPerInvoke = loops * 10)] + public int GetKnownMethod_CONNECT() + { + Span data = _methodConnect; + + return GetKnownMethod(data); + } + + [Benchmark(OperationsPerInvoke = loops * 10)] + public int GetKnownMethod_DELETE() + { + Span data = _methodDelete; + + return GetKnownMethod(data); + } + [Benchmark(OperationsPerInvoke = loops * 10)] + public int GetKnownMethod_HEAD() + { + Span data = _methodHead; + + return GetKnownMethod(data); + } + + [Benchmark(OperationsPerInvoke = loops * 10)] + public int GetKnownMethod_PATCH() + { + Span data = _methodPatch; + + return GetKnownMethod(data); + } + [Benchmark(OperationsPerInvoke = loops * 10)] + public int GetKnownMethod_POST() + { + Span data = _methodPost; + + return GetKnownMethod(data); + } + [Benchmark(OperationsPerInvoke = loops * 10)] + public int GetKnownMethod_PUT() + { + Span data = _methodPut; + + return GetKnownMethod(data); + } + + [Benchmark(OperationsPerInvoke = loops * 10)] + public int GetKnownMethod_OPTIONS() + { + Span data = _methodOptions; + + return GetKnownMethod(data); + } + + [Benchmark(OperationsPerInvoke = loops * 10)] + public int GetKnownMethod_TRACE() + { + Span data = _methodTrace; + + return GetKnownMethod(data); + } + + private int GetKnownMethod(Span data) + { + int len = 0; + HttpMethod method; + + for (int i = 0; i < loops; i++) + { + data.GetKnownMethod(out method, out var length); + len += length; + data.GetKnownMethod(out method, out length); + len += length; + data.GetKnownMethod(out method, out length); + len += length; + data.GetKnownMethod(out method, out length); + len += length; + data.GetKnownMethod(out method, out length); + len += length; + data.GetKnownMethod(out method, out length); + len += length; + data.GetKnownMethod(out method, out length); + len += length; + data.GetKnownMethod(out method, out length); + len += length; + data.GetKnownMethod(out method, out length); + len += length; + data.GetKnownMethod(out method, out length); + len += length; + } + return len; + } + + [Benchmark(OperationsPerInvoke = loops * 10)] + public int GetKnownVersion_HTTP1_1() + { + int len = 0; + HttpVersion version; + Span data = _version; + for (int i = 0; i < loops; i++) + { + data.GetKnownVersion(out version, out var length); + len += length; + data.GetKnownVersion(out version, out length); + len += length; + data.GetKnownVersion(out version, out length); + len += length; + data.GetKnownVersion(out version, out length); + len += length; + data.GetKnownVersion(out version, out length); + len += length; + data.GetKnownVersion(out version, out length); + len += length; + data.GetKnownVersion(out version, out length); + len += length; + data.GetKnownVersion(out version, out length); + len += length; + data.GetKnownVersion(out version, out length); + len += length; + data.GetKnownVersion(out version, out length); + len += length; + } + return len; + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/Microsoft.AspNetCore.Server.Kestrel.Performance.csproj b/src/Servers/Kestrel/perf/Kestrel.Performance/Microsoft.AspNetCore.Server.Kestrel.Performance.csproj new file mode 100644 index 0000000000..1998485c20 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/Microsoft.AspNetCore.Server.Kestrel.Performance.csproj @@ -0,0 +1,27 @@ + + + + netcoreapp2.0 + Exe + true + true + false + + + + + + + + + + + + + + + + + + + diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/Mocks/MockTimeoutControl.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/Mocks/MockTimeoutControl.cs new file mode 100644 index 0000000000..879c86bab6 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/Mocks/MockTimeoutControl.cs @@ -0,0 +1,50 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance.Mocks +{ + public class MockTimeoutControl : ITimeoutControl + { + public void CancelTimeout() + { + } + + public void ResetTimeout(long ticks, TimeoutAction timeoutAction) + { + } + + public void SetTimeout(long ticks, TimeoutAction timeoutAction) + { + } + + public void StartTimingReads() + { + } + + public void StopTimingReads() + { + } + + public void PauseTimingReads() + { + } + + public void ResumeTimingReads() + { + } + + public void BytesRead(long count) + { + } + + public void StartTimingWrite(long size) + { + } + + public void StopTimingWrite() + { + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/Mocks/MockTrace.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/Mocks/MockTrace.cs new file mode 100644 index 0000000000..8386402a78 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/Mocks/MockTrace.cs @@ -0,0 +1,52 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class MockTrace : IKestrelTrace + { + public void ApplicationError(string connectionId, string requestId, Exception ex) { } + public IDisposable BeginScope(TState state) => null; + public void ConnectionBadRequest(string connectionId, BadHttpRequestException ex) { } + public void ConnectionDisconnect(string connectionId) { } + public void ConnectionError(string connectionId, Exception ex) { } + public void ConnectionHeadResponseBodyWrite(string connectionId, long count) { } + public void ConnectionKeepAlive(string connectionId) { } + public void ConnectionPause(string connectionId) { } + public void ConnectionRead(string connectionId, int count) { } + public void ConnectionReadFin(string connectionId) { } + public void ConnectionReset(string connectionId) { } + public void ConnectionResume(string connectionId) { } + public void ConnectionRejected(string connectionId) { } + public void ConnectionStart(string connectionId) { } + public void ConnectionStop(string connectionId) { } + public void ConnectionWrite(string connectionId, int count) { } + public void ConnectionWriteCallback(string connectionId, int status) { } + public void ConnectionWriteFin(string connectionId) { } + public void ConnectionWroteFin(string connectionId, int status) { } + public bool IsEnabled(LogLevel logLevel) => false; + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) { } + public void NotAllConnectionsAborted() { } + public void NotAllConnectionsClosedGracefully() { } + public void RequestProcessingError(string connectionId, Exception ex) { } + public void HeartbeatSlow(TimeSpan interval, DateTimeOffset now) { } + public void ApplicationNeverCompleted(string connectionId) { } + public void RequestBodyStart(string connectionId, string traceIdentifier) { } + public void RequestBodyDone(string connectionId, string traceIdentifier) { } + public void RequestBodyNotEntirelyRead(string connectionId, string traceIdentifier) { } + public void RequestBodyDrainTimedOut(string connectionId, string traceIdentifier) { } + public void RequestBodyMininumDataRateNotSatisfied(string connectionId, string traceIdentifier, double rate) { } + public void ResponseMininumDataRateNotSatisfied(string connectionId, string traceIdentifier) { } + public void ApplicationAbortedConnection(string connectionId, string traceIdentifier) { } + public void Http2ConnectionError(string connectionId, Http2ConnectionErrorException ex) { } + public void Http2StreamError(string connectionId, Http2StreamErrorException ex) { } + public void HPackDecodingError(string connectionId, int streamId, HPackDecodingException ex) { } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/Mocks/NullParser.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/Mocks/NullParser.cs new file mode 100644 index 0000000000..556eaa7dbb --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/Mocks/NullParser.cs @@ -0,0 +1,57 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Text; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class NullParser : IHttpParser where TRequestHandler : struct, IHttpHeadersHandler, IHttpRequestLineHandler + { + private readonly byte[] _startLine = Encoding.ASCII.GetBytes("GET /plaintext HTTP/1.1\r\n"); + private readonly byte[] _target = Encoding.ASCII.GetBytes("/plaintext"); + private readonly byte[] _hostHeaderName = Encoding.ASCII.GetBytes("Host"); + private readonly byte[] _hostHeaderValue = Encoding.ASCII.GetBytes("www.example.com"); + private readonly byte[] _acceptHeaderName = Encoding.ASCII.GetBytes("Accept"); + private readonly byte[] _acceptHeaderValue = Encoding.ASCII.GetBytes("text/plain,text/html;q=0.9,application/xhtml+xml;q=0.9,application/xml;q=0.8,*/*;q=0.7\r\n\r\n"); + private readonly byte[] _connectionHeaderName = Encoding.ASCII.GetBytes("Connection"); + private readonly byte[] _connectionHeaderValue = Encoding.ASCII.GetBytes("keep-alive"); + + public static readonly NullParser Instance = new NullParser(); + + public bool ParseHeaders(TRequestHandler handler, in ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined, out int consumedBytes) + { + handler.OnHeader(new Span(_hostHeaderName), new Span(_hostHeaderValue)); + handler.OnHeader(new Span(_acceptHeaderName), new Span(_acceptHeaderValue)); + handler.OnHeader(new Span(_connectionHeaderName), new Span(_connectionHeaderValue)); + + consumedBytes = 0; + consumed = buffer.Start; + examined = buffer.End; + + return true; + } + + public bool ParseRequestLine(TRequestHandler handler, in ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + handler.OnStartLine(HttpMethod.Get, + HttpVersion.Http11, + new Span(_target), + new Span(_target), + Span.Empty, + Span.Empty, + false); + + consumed = buffer.Start; + examined = buffer.End; + + return true; + } + + public void Reset() + { + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/PipeThroughputBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/PipeThroughputBenchmark.cs new file mode 100644 index 0000000000..b647567715 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/PipeThroughputBenchmark.cs @@ -0,0 +1,74 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class PipeThroughputBenchmark + { + private const int _writeLenght = 57; + private const int InnerLoopCount = 512; + + private Pipe _pipe; + private MemoryPool _memoryPool; + + [IterationSetup] + public void Setup() + { + _memoryPool = KestrelMemoryPool.Create(); + _pipe = new Pipe(new PipeOptions(_memoryPool)); + } + + [Benchmark(OperationsPerInvoke = InnerLoopCount)] + public void ParseLiveAspNetTwoTasks() + { + var writing = Task.Run(async () => + { + for (int i = 0; i < InnerLoopCount; i++) + { + _pipe.Writer.GetMemory(_writeLenght); + _pipe.Writer.Advance(_writeLenght); + await _pipe.Writer.FlushAsync(); + } + }); + + var reading = Task.Run(async () => + { + long remaining = InnerLoopCount * _writeLenght; + while (remaining != 0) + { + var result = await _pipe.Reader.ReadAsync(); + remaining -= result.Buffer.Length; + _pipe.Reader.AdvanceTo(result.Buffer.End, result.Buffer.End); + } + }); + + Task.WaitAll(writing, reading); + } + + [Benchmark(OperationsPerInvoke = InnerLoopCount)] + public void ParseLiveAspNetInline() + { + for (int i = 0; i < InnerLoopCount; i++) + { + _pipe.Writer.GetMemory(_writeLenght); + _pipe.Writer.Advance(_writeLenght); + _pipe.Writer.FlushAsync().GetAwaiter().GetResult(); + var result = _pipe.Reader.ReadAsync().GetAwaiter().GetResult(); + _pipe.Reader.AdvanceTo(result.Buffer.End, result.Buffer.End); + } + } + + [IterationCleanup] + public void Cleanup() + { + _memoryPool.Dispose(); + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/README.md b/src/Servers/Kestrel/perf/Kestrel.Performance/README.md new file mode 100644 index 0000000000..91991f82ff --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/README.md @@ -0,0 +1,11 @@ +Compile the solution in Release mode (so Kestrel is available in release) + +To run a specific benchmark add it as parameter +``` +dotnet run -f netcoreapp2.0 -c Release RequestParsing +``` +To run all use `All` as parameter +``` +dotnet run -f netcoreapp2.0 -c Release All +``` +Using no parameter will list all available benchmarks diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/RequestParsingBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/RequestParsingBenchmark.cs new file mode 100644 index 0000000000..2d881588d2 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/RequestParsingBenchmark.cs @@ -0,0 +1,216 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.IO.Pipelines; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Performance.Mocks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class RequestParsingBenchmark + { + private MemoryPool _memoryPool; + + public Pipe Pipe { get; set; } + + public Http1Connection Http1Connection { get; set; } + + [IterationSetup] + public void Setup() + { + _memoryPool = KestrelMemoryPool.Create(); + var options = new PipeOptions(_memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + + var serviceContext = new ServiceContext + { + DateHeaderValueManager = new DateHeaderValueManager(), + ServerOptions = new KestrelServerOptions(), + Log = new MockTrace(), + HttpParser = new HttpParser() + }; + + var http1Connection = new Http1Connection(new Http1ConnectionContext + { + ServiceContext = serviceContext, + ConnectionFeatures = new FeatureCollection(), + MemoryPool = _memoryPool, + Application = pair.Application, + Transport = pair.Transport, + TimeoutControl = new MockTimeoutControl() + }); + + http1Connection.Reset(); + + Http1Connection = http1Connection; + Pipe = new Pipe(new PipeOptions(_memoryPool)); + } + + [Benchmark(Baseline = true, OperationsPerInvoke = RequestParsingData.InnerLoopCount)] + public void PlaintextTechEmpower() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.PlaintextTechEmpowerRequest); + ParseData(); + } + } + + [Benchmark(OperationsPerInvoke = RequestParsingData.InnerLoopCount)] + public void PlaintextAbsoluteUri() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.PlaintextAbsoluteUriRequest); + ParseData(); + } + } + + [Benchmark(OperationsPerInvoke = RequestParsingData.InnerLoopCount * RequestParsingData.Pipelining)] + public void PipelinedPlaintextTechEmpower() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.PlaintextTechEmpowerPipelinedRequests); + ParseData(); + } + } + + [Benchmark(OperationsPerInvoke = RequestParsingData.InnerLoopCount * RequestParsingData.Pipelining)] + public void PipelinedPlaintextTechEmpowerDrainBuffer() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.PlaintextTechEmpowerPipelinedRequests); + ParseDataDrainBuffer(); + } + } + + [Benchmark(OperationsPerInvoke = RequestParsingData.InnerLoopCount)] + public void LiveAspNet() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.LiveaspnetRequest); + ParseData(); + } + } + + [Benchmark(OperationsPerInvoke = RequestParsingData.InnerLoopCount * RequestParsingData.Pipelining)] + public void PipelinedLiveAspNet() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.LiveaspnetPipelinedRequests); + ParseData(); + } + } + + [Benchmark(OperationsPerInvoke = RequestParsingData.InnerLoopCount)] + public void Unicode() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.UnicodeRequest); + ParseData(); + } + } + + [Benchmark(OperationsPerInvoke = RequestParsingData.InnerLoopCount * RequestParsingData.Pipelining)] + public void UnicodePipelined() + { + for (var i = 0; i < RequestParsingData.InnerLoopCount; i++) + { + InsertData(RequestParsingData.UnicodePipelinedRequests); + ParseData(); + } + } + + private void InsertData(byte[] bytes) + { + Pipe.Writer.Write(bytes); + // There should not be any backpressure and task completes immediately + Pipe.Writer.FlushAsync().GetAwaiter().GetResult(); + } + + private void ParseDataDrainBuffer() + { + var awaitable = Pipe.Reader.ReadAsync(); + if (!awaitable.IsCompleted) + { + // No more data + return; + } + + var readableBuffer = awaitable.GetAwaiter().GetResult().Buffer; + do + { + Http1Connection.Reset(); + + if (!Http1Connection.TakeStartLine(readableBuffer, out var consumed, out var examined)) + { + ErrorUtilities.ThrowInvalidRequestLine(); + } + + readableBuffer = readableBuffer.Slice(consumed); + + if (!Http1Connection.TakeMessageHeaders(readableBuffer, out consumed, out examined)) + { + ErrorUtilities.ThrowInvalidRequestHeaders(); + } + + readableBuffer = readableBuffer.Slice(consumed); + } + while (readableBuffer.Length > 0); + + Pipe.Reader.AdvanceTo(readableBuffer.End); + } + + private void ParseData() + { + do + { + var awaitable = Pipe.Reader.ReadAsync(); + if (!awaitable.IsCompleted) + { + // No more data + return; + } + + var result = awaitable.GetAwaiter().GetResult(); + var readableBuffer = result.Buffer; + + Http1Connection.Reset(); + + if (!Http1Connection.TakeStartLine(readableBuffer, out var consumed, out var examined)) + { + ErrorUtilities.ThrowInvalidRequestLine(); + } + Pipe.Reader.AdvanceTo(consumed, examined); + + result = Pipe.Reader.ReadAsync().GetAwaiter().GetResult(); + readableBuffer = result.Buffer; + + if (!Http1Connection.TakeMessageHeaders(readableBuffer, out consumed, out examined)) + { + ErrorUtilities.ThrowInvalidRequestHeaders(); + } + Pipe.Reader.AdvanceTo(consumed, examined); + } + while (true); + } + + + [IterationCleanup] + public void Cleanup() + { + _memoryPool.Dispose(); + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/RequestParsingData.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/RequestParsingData.cs new file mode 100644 index 0000000000..5c496960bb --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/RequestParsingData.cs @@ -0,0 +1,70 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Linq; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class RequestParsingData + { + public const int InnerLoopCount = 512; + + public const int Pipelining = 16; + + private const string _plaintextTechEmpowerRequest = + "GET /plaintext HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Accept: text/plain,text/html;q=0.9,application/xhtml+xml;q=0.9,application/xml;q=0.8,*/*;q=0.7\r\n" + + "Connection: keep-alive\r\n" + + "\r\n"; + + // edge-casey - client's don't normally send this + private const string _plaintextAbsoluteUriRequest = + "GET http://localhost/plaintext HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Accept: text/plain,text/html;q=0.9,application/xhtml+xml;q=0.9,application/xml;q=0.8,*/*;q=0.7\r\n" + + "Connection: keep-alive\r\n" + + "\r\n"; + + private const string _liveaspnetRequest = + "GET / HTTP/1.1\r\n" + + "Host: live.asp.net\r\n" + + "Connection: keep-alive\r\n" + + "Upgrade-Insecure-Requests: 1\r\n" + + "User-Agent: Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/54.0.2840.99 Safari/537.36\r\n" + + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8\r\n" + + "DNT: 1\r\n" + + "Accept-Encoding: gzip, deflate, sdch, br\r\n" + + "Accept-Language: en-US,en;q=0.8\r\n" + + "Cookie: __unam=7a67379-1s65dc575c4-6d778abe-1; omniID=9519gfde_3347_4762_8762_df51458c8ec2\r\n" + + "\r\n"; + + private const string _unicodeRequest = + "GET /questions/40148683/why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric HTTP/1.1\r\n" + + "Accept: text/html, application/xhtml+xml, image/jxr, */*\r\n" + + "Accept-Language: en-US,en-GB;q=0.7,en;q=0.3\r\n" + + "User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/52.0.2743.116 Safari/537.36 Edge/15.14965\r\n" + + "Accept-Encoding: gzip, deflate\r\n" + + "Host: stackoverflow.com\r\n" + + "Connection: Keep-Alive\r\n" + + "Cache-Control: max-age=0\r\n" + + "Upgrade-Insecure-Requests: 1\r\n" + + "DNT: 1\r\n" + + "Referer: http://stackoverflow.com/?tab=month\r\n" + + "Pragma: no-cache\r\n" + + "Cookie: prov=20629ccd-8b0f-e8ef-2935-cd26609fc0bc; __qca=P0-1591065732-1479167353442; _ga=GA1.2.1298898376.1479167354; _gat=1; sgt=id=9519gfde_3347_4762_8762_df51458c8ec2; acct=t=why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric&s=why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric\r\n" + + "\r\n"; + + public static readonly byte[] PlaintextTechEmpowerPipelinedRequests = Encoding.ASCII.GetBytes(string.Concat(Enumerable.Repeat(_plaintextTechEmpowerRequest, Pipelining))); + public static readonly byte[] PlaintextTechEmpowerRequest = Encoding.ASCII.GetBytes(_plaintextTechEmpowerRequest); + + public static readonly byte[] PlaintextAbsoluteUriRequest = Encoding.ASCII.GetBytes(_plaintextAbsoluteUriRequest); + + public static readonly byte[] LiveaspnetPipelinedRequests = Encoding.ASCII.GetBytes(string.Concat(Enumerable.Repeat(_liveaspnetRequest, Pipelining))); + public static readonly byte[] LiveaspnetRequest = Encoding.ASCII.GetBytes(_liveaspnetRequest); + + public static readonly byte[] UnicodePipelinedRequests = Encoding.ASCII.GetBytes(string.Concat(Enumerable.Repeat(_unicodeRequest, Pipelining))); + public static readonly byte[] UnicodeRequest = Encoding.ASCII.GetBytes(_unicodeRequest); + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/ResponseHeaderCollectionBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/ResponseHeaderCollectionBenchmark.cs new file mode 100644 index 0000000000..e04f007222 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/ResponseHeaderCollectionBenchmark.cs @@ -0,0 +1,221 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.IO.Pipelines; +using System.Runtime.CompilerServices; +using System.Text; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Http.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class ResponseHeaderCollectionBenchmark + { + private const int InnerLoopCount = 512; + + private static readonly byte[] _bytesServer = Encoding.ASCII.GetBytes("\r\nServer: Kestrel"); + private static readonly DateHeaderValueManager _dateHeaderValueManager = new DateHeaderValueManager(); + private HttpResponseHeaders _responseHeadersDirect; + private HttpResponse _response; + + public enum BenchmarkTypes + { + ContentLengthNumeric, + ContentLengthString, + Plaintext, + Common, + Unknown + } + + [Params( + BenchmarkTypes.ContentLengthNumeric, + BenchmarkTypes.ContentLengthString, + BenchmarkTypes.Plaintext, + BenchmarkTypes.Common, + BenchmarkTypes.Unknown + )] + public BenchmarkTypes Type { get; set; } + + [Benchmark(OperationsPerInvoke = InnerLoopCount)] + public void SetHeaders() + { + switch (Type) + { + case BenchmarkTypes.ContentLengthNumeric: + ContentLengthNumeric(InnerLoopCount); + break; + case BenchmarkTypes.ContentLengthString: + ContentLengthString(InnerLoopCount); + break; + case BenchmarkTypes.Plaintext: + Plaintext(InnerLoopCount); + break; + case BenchmarkTypes.Common: + Common(InnerLoopCount); + break; + case BenchmarkTypes.Unknown: + Unknown(InnerLoopCount); + break; + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void ContentLengthNumeric(int count) + { + for (var i = 0; i < count; i++) + { + _responseHeadersDirect.Reset(); + + _response.ContentLength = 0; + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void ContentLengthString(int count) + { + for (var i = 0; i < count; i++) + { + _responseHeadersDirect.Reset(); + + _response.Headers["Content-Length"] = "0"; + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void Plaintext(int count) + { + for (var i = 0; i < count; i++) + { + _responseHeadersDirect.Reset(); + + _response.StatusCode = 200; + _response.ContentType = "text/plain"; + _response.ContentLength = 13; + + var dateHeaderValues = _dateHeaderValueManager.GetDateHeaderValues(); + _responseHeadersDirect.SetRawDate(dateHeaderValues.String, dateHeaderValues.Bytes); + _responseHeadersDirect.SetRawServer("Kestrel", _bytesServer); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void Common(int count) + { + for (var i = 0; i < count; i++) + { + _responseHeadersDirect.Reset(); + + _response.StatusCode = 200; + _response.ContentType = "text/css"; + _response.ContentLength = 421; + + var headers = _response.Headers; + + headers["Connection"] = "Close"; + headers["Cache-Control"] = "public, max-age=30672000"; + headers["Vary"] = "Accept-Encoding"; + headers["Content-Encoding"] = "gzip"; + headers["Expires"] = "Fri, 12 Jan 2018 22:01:55 GMT"; + headers["Last-Modified"] = "Wed, 22 Jun 2016 20:08:29 GMT"; + headers["Set-Cookie"] = "prov=20629ccd-8b0f-e8ef-2935-cd26609fc0bc; __qca=P0-1591065732-1479167353442; _ga=GA1.2.1298898376.1479167354; _gat=1; sgt=id=9519gfde_3347_4762_8762_df51458c8ec2; acct=t=why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric&s=why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric"; + headers["ETag"] = "\"54ef7954-1078\""; + headers["Transfer-Encoding"] = "chunked"; + headers["Content-Language"] = "en-gb"; + headers["Upgrade"] = "websocket"; + headers["Via"] = "1.1 varnish"; + headers["Access-Control-Allow-Origin"] = "*"; + headers["Access-Control-Allow-credentials"] = "true"; + headers["Access-Control-Expose-Headers"] = "Client-Protocol, Content-Length, Content-Type, X-Bandwidth-Est, X-Bandwidth-Est2, X-Bandwidth-Est-Comp, X-Bandwidth-Avg, X-Walltime-Ms, X-Sequence-Num"; + + var dateHeaderValues = _dateHeaderValueManager.GetDateHeaderValues(); + _responseHeadersDirect.SetRawDate(dateHeaderValues.String, dateHeaderValues.Bytes); + _responseHeadersDirect.SetRawServer("Kestrel", _bytesServer); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void Unknown(int count) + { + for (var i = 0; i < count; i++) + { + _responseHeadersDirect.Reset(); + + _response.StatusCode = 200; + _response.ContentType = "text/plain"; + _response.ContentLength = 13; + + var headers = _response.Headers; + + headers["Link"] = "; rel=\"canonical\""; + headers["X-Ua-Compatible"] = "IE=Edge"; + headers["X-Powered-By"] = "ASP.NET"; + headers["X-Content-Type-Options"] = "nosniff"; + headers["X-Xss-Protection"] = "1; mode=block"; + headers["X-Frame-Options"] = "SAMEORIGIN"; + headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains; preload"; + headers["Content-Security-Policy"] = "default-src 'none'; script-src 'self' cdnjs.cloudflare.com code.jquery.com scotthelme.disqus.com a.disquscdn.com www.google-analytics.com go.disqus.com platform.twitter.com cdn.syndication.twimg.com; style-src 'self' a.disquscdn.com fonts.googleapis.com cdnjs.cloudflare.com platform.twitter.com; img-src 'self' data: www.gravatar.com www.google-analytics.com links.services.disqus.com referrer.disqus.com a.disquscdn.com cdn.syndication.twimg.com syndication.twitter.com pbs.twimg.com platform.twitter.com abs.twimg.com; child-src fusiontables.googleusercontent.com fusiontables.google.com www.google.com disqus.com www.youtube.com syndication.twitter.com platform.twitter.com; frame-src fusiontables.googleusercontent.com fusiontables.google.com www.google.com disqus.com www.youtube.com syndication.twitter.com platform.twitter.com; connect-src 'self' links.services.disqus.com; font-src 'self' cdnjs.cloudflare.com fonts.gstatic.com fonts.googleapis.com; form-action 'self'; upgrade-insecure-requests;"; + + var dateHeaderValues = _dateHeaderValueManager.GetDateHeaderValues(); + _responseHeadersDirect.SetRawDate(dateHeaderValues.String, dateHeaderValues.Bytes); + _responseHeadersDirect.SetRawServer("Kestrel", _bytesServer); + } + } + + [IterationSetup] + public void Setup() + { + var memoryPool = KestrelMemoryPool.Create(); + var options = new PipeOptions(memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + + var serviceContext = new ServiceContext + { + DateHeaderValueManager = new DateHeaderValueManager(), + ServerOptions = new KestrelServerOptions(), + Log = new MockTrace(), + HttpParser = new HttpParser() + }; + + var http1Connection = new Http1Connection(new Http1ConnectionContext + { + ServiceContext = serviceContext, + ConnectionFeatures = new FeatureCollection(), + MemoryPool = memoryPool, + Application = pair.Application, + Transport = pair.Transport + }); + + http1Connection.Reset(); + + _responseHeadersDirect = (HttpResponseHeaders)http1Connection.ResponseHeaders; + var context = new DefaultHttpContext(http1Connection); + _response = new DefaultHttpResponse(context); + + switch (Type) + { + case BenchmarkTypes.ContentLengthNumeric: + ContentLengthNumeric(1); + break; + case BenchmarkTypes.ContentLengthString: + ContentLengthString(1); + break; + case BenchmarkTypes.Plaintext: + Plaintext(1); + break; + case BenchmarkTypes.Common: + Common(1); + break; + case BenchmarkTypes.Unknown: + Unknown(1); + break; + } + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/ResponseHeadersWritingBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/ResponseHeadersWritingBenchmark.cs new file mode 100644 index 0000000000..69df7fdb88 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/ResponseHeadersWritingBenchmark.cs @@ -0,0 +1,158 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Performance.Mocks; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class ResponseHeadersWritingBenchmark + { + private static readonly byte[] _helloWorldPayload = Encoding.ASCII.GetBytes("Hello, World!"); + + private TestHttp1Connection _http1Connection; + + private MemoryPool _memoryPool; + + [Params( + BenchmarkTypes.TechEmpowerPlaintext, + BenchmarkTypes.PlaintextChunked, + BenchmarkTypes.PlaintextWithCookie, + BenchmarkTypes.PlaintextChunkedWithCookie, + BenchmarkTypes.LiveAspNet + )] + public BenchmarkTypes Type { get; set; } + + [Benchmark] + public async Task Output() + { + _http1Connection.Reset(); + _http1Connection.StatusCode = 200; + _http1Connection.HttpVersionEnum = HttpVersion.Http11; + _http1Connection.KeepAlive = true; + + Task writeTask = Task.CompletedTask; + switch (Type) + { + case BenchmarkTypes.TechEmpowerPlaintext: + writeTask = TechEmpowerPlaintext(); + break; + case BenchmarkTypes.PlaintextChunked: + writeTask = PlaintextChunked(); + break; + case BenchmarkTypes.PlaintextWithCookie: + writeTask = PlaintextWithCookie(); + break; + case BenchmarkTypes.PlaintextChunkedWithCookie: + writeTask = PlaintextChunkedWithCookie(); + break; + case BenchmarkTypes.LiveAspNet: + writeTask = LiveAspNet(); + break; + } + + await writeTask; + await _http1Connection.ProduceEndAsync(); + } + + private Task TechEmpowerPlaintext() + { + var responseHeaders = _http1Connection.ResponseHeaders; + responseHeaders["Content-Type"] = "text/plain"; + responseHeaders.ContentLength = _helloWorldPayload.Length; + return _http1Connection.WriteAsync(new ArraySegment(_helloWorldPayload), default(CancellationToken)); + } + + private Task PlaintextChunked() + { + var responseHeaders = _http1Connection.ResponseHeaders; + responseHeaders["Content-Type"] = "text/plain"; + return _http1Connection.WriteAsync(new ArraySegment(_helloWorldPayload), default(CancellationToken)); + } + + private Task LiveAspNet() + { + var responseHeaders = _http1Connection.ResponseHeaders; + responseHeaders["Content-Encoding"] = "gzip"; + responseHeaders["Content-Type"] = "text/html; charset=utf-8"; + responseHeaders["Strict-Transport-Security"] = "max-age=31536000; includeSubdomains"; + responseHeaders["Vary"] = "Accept-Encoding"; + responseHeaders["X-Powered-By"] = "ASP.NET"; + return _http1Connection.WriteAsync(new ArraySegment(_helloWorldPayload), default(CancellationToken)); + } + + private Task PlaintextWithCookie() + { + var responseHeaders = _http1Connection.ResponseHeaders; + responseHeaders["Content-Type"] = "text/plain"; + responseHeaders["Set-Cookie"] = "prov=20629ccd-8b0f-e8ef-2935-cd26609fc0bc; __qca=P0-1591065732-1479167353442; _ga=GA1.2.1298898376.1479167354; _gat=1; sgt=id=9519gfde_3347_4762_8762_df51458c8ec2; acct=t=why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric&s=why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric"; + responseHeaders.ContentLength = _helloWorldPayload.Length; + return _http1Connection.WriteAsync(new ArraySegment(_helloWorldPayload), default(CancellationToken)); + } + + private Task PlaintextChunkedWithCookie() + { + var responseHeaders = _http1Connection.ResponseHeaders; + responseHeaders["Content-Type"] = "text/plain"; + responseHeaders["Set-Cookie"] = "prov=20629ccd-8b0f-e8ef-2935-cd26609fc0bc; __qca=P0-1591065732-1479167353442; _ga=GA1.2.1298898376.1479167354; _gat=1; sgt=id=9519gfde_3347_4762_8762_df51458c8ec2; acct=t=why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric&s=why-is-%e0%a5%a7%e0%a5%a8%e0%a5%a9-numeric"; + return _http1Connection.WriteAsync(new ArraySegment(_helloWorldPayload), default(CancellationToken)); + } + + [IterationSetup] + public void Setup() + { + _memoryPool = KestrelMemoryPool.Create(); + var options = new PipeOptions(_memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + + var serviceContext = new ServiceContext + { + DateHeaderValueManager = new DateHeaderValueManager(), + ServerOptions = new KestrelServerOptions(), + Log = new MockTrace(), + HttpParser = new HttpParser() + }; + + var http1Connection = new TestHttp1Connection(new Http1ConnectionContext + { + ServiceContext = serviceContext, + ConnectionFeatures = new FeatureCollection(), + MemoryPool = _memoryPool, + TimeoutControl = new MockTimeoutControl(), + Application = pair.Application, + Transport = pair.Transport + }); + + http1Connection.Reset(); + + _http1Connection = http1Connection; + } + + [IterationCleanup] + public void Cleanup() + { + _memoryPool.Dispose(); + } + + public enum BenchmarkTypes + { + TechEmpowerPlaintext, + PlaintextChunked, + PlaintextWithCookie, + PlaintextChunkedWithCookie, + LiveAspNet + } + } +} diff --git a/src/Servers/Kestrel/perf/Kestrel.Performance/StringUtilitiesBenchmark.cs b/src/Servers/Kestrel/perf/Kestrel.Performance/StringUtilitiesBenchmark.cs new file mode 100644 index 0000000000..e2b2e9eab1 --- /dev/null +++ b/src/Servers/Kestrel/perf/Kestrel.Performance/StringUtilitiesBenchmark.cs @@ -0,0 +1,33 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Performance +{ + public class StringUtilitiesBenchmark + { + private const int Iterations = 500_000; + + [Benchmark(Baseline = true, OperationsPerInvoke = Iterations)] + public void UintToString() + { + var connectionId = CorrelationIdGenerator.GetNextId(); + for (uint i = 0; i < Iterations; i++) + { + var id = connectionId + ':' + i.ToString("X8"); + } + } + + [Benchmark(OperationsPerInvoke = Iterations)] + public void ConcatAsHexSuffix() + { + var connectionId = CorrelationIdGenerator.GetNextId(); + for (uint i = 0; i < Iterations; i++) + { + var id = StringUtilities.ConcatAsHexSuffix(connectionId, ':', i); + } + } + } +} diff --git a/src/Servers/Kestrel/perf/PlatformBenchmarks/AsciiString.cs b/src/Servers/Kestrel/perf/PlatformBenchmarks/AsciiString.cs new file mode 100644 index 0000000000..f84d42c978 --- /dev/null +++ b/src/Servers/Kestrel/perf/PlatformBenchmarks/AsciiString.cs @@ -0,0 +1,50 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Text; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace PlatformBenchmarks +{ + public struct AsciiString : IEquatable + { + private readonly byte[] _data; + + public AsciiString(string s) => _data = Encoding.ASCII.GetBytes(s); + + public int Length => _data.Length; + + public ReadOnlySpan AsSpan() => _data; + + public static implicit operator ReadOnlySpan(AsciiString str) => str._data; + public static implicit operator byte[] (AsciiString str) => str._data; + + public static implicit operator AsciiString(string str) => new AsciiString(str); + + public override string ToString() => HttpUtilities.GetAsciiStringNonNullCharacters(_data); + public static explicit operator string(AsciiString str) => str.ToString(); + + public bool Equals(AsciiString other) => ReferenceEquals(_data, other._data) || SequenceEqual(_data, other._data); + private bool SequenceEqual(byte[] data1, byte[] data2) => new Span(data1).SequenceEqual(data2); + + public static bool operator ==(AsciiString a, AsciiString b) => a.Equals(b); + public static bool operator !=(AsciiString a, AsciiString b) => !a.Equals(b); + public override bool Equals(object other) => (other is AsciiString) && Equals((AsciiString)other); + + public override int GetHashCode() + { + // Copied from x64 version of string.GetLegacyNonRandomizedHashCode() + // https://github.com/dotnet/coreclr/blob/master/src/mscorlib/src/System/String.Comparison.cs + var data = _data; + int hash1 = 5381; + int hash2 = hash1; + foreach (int b in data) + { + hash1 = ((hash1 << 5) + hash1) ^ b; + } + return hash1 + (hash2 * 1566083941); + } + + } +} diff --git a/src/Servers/Kestrel/perf/PlatformBenchmarks/BenchmarkApplication.cs b/src/Servers/Kestrel/perf/PlatformBenchmarks/BenchmarkApplication.cs new file mode 100644 index 0000000000..6551bee358 --- /dev/null +++ b/src/Servers/Kestrel/perf/PlatformBenchmarks/BenchmarkApplication.cs @@ -0,0 +1,156 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Utf8Json; + +namespace PlatformBenchmarks +{ + public class BenchmarkApplication : HttpConnection + { + private static AsciiString _crlf = "\r\n"; + private static AsciiString _eoh = "\r\n\r\n"; // End Of Headers + private static AsciiString _http11OK = "HTTP/1.1 200 OK\r\n"; + private static AsciiString _headerServer = "Server: Custom"; + private static AsciiString _headerContentLength = "Content-Length: "; + private static AsciiString _headerContentLengthZero = "Content-Length: 0\r\n"; + private static AsciiString _headerContentTypeText = "Content-Type: text/plain\r\n"; + private static AsciiString _headerContentTypeJson = "Content-Type: application/json\r\n"; + + + private static AsciiString _plainTextBody = "Hello, World!"; + + private static class Paths + { + public static AsciiString Plaintext = "/plaintext"; + public static AsciiString Json = "/json"; + } + + private bool _isPlainText; + private bool _isJson; + + public override void OnStartLine(HttpMethod method, HttpVersion version, Span target, Span path, Span query, Span customMethod, bool pathEncoded) + { + if (path.StartsWith(Paths.Plaintext) && method == HttpMethod.Get) + { + _isPlainText = true; + } + else if (path.StartsWith(Paths.Json) && method == HttpMethod.Get) + { + _isJson = true; + } + else + { + _isPlainText = false; + _isJson = false; + } + } + + public override void OnHeader(Span name, Span value) + { + } + + public override ValueTask ProcessRequestAsync() + { + if (_isPlainText) + { + PlainText(Writer); + } + else if (_isJson) + { + Json(Writer); + } + else + { + Default(Writer); + } + + return default; + } + + public override async ValueTask OnReadCompletedAsync() + { + await Writer.FlushAsync(); + } + private static void PlainText(PipeWriter pipeWriter) + { + var writer = new BufferWriter(pipeWriter); + // HTTP 1.1 OK + writer.Write(_http11OK); + + // Server headers + writer.Write(_headerServer); + + // Date header + writer.Write(DateHeader.HeaderBytes); + + // Content-Type header + writer.Write(_headerContentTypeText); + + // Content-Length header + writer.Write(_headerContentLength); + writer.WriteNumeric((ulong)_plainTextBody.Length); + + // End of headers + writer.Write(_eoh); + + // Body + writer.Write(_plainTextBody); + writer.Commit(); + } + + private static void Json(PipeWriter pipeWriter) + { + var writer = new BufferWriter(pipeWriter); + + // HTTP 1.1 OK + writer.Write(_http11OK); + + // Server headers + writer.Write(_headerServer); + + // Date header + writer.Write(DateHeader.HeaderBytes); + + // Content-Type header + writer.Write(_headerContentTypeJson); + + // Content-Length header + writer.Write(_headerContentLength); + var jsonPayload = JsonSerializer.SerializeUnsafe(new { message = "Hello, World!" }); + writer.WriteNumeric((ulong)jsonPayload.Count); + + // End of headers + writer.Write(_eoh); + + // Body + writer.Write(jsonPayload); + writer.Commit(); + } + + private static void Default(PipeWriter pipeWriter) + { + var writer = new BufferWriter(pipeWriter); + + // HTTP 1.1 OK + writer.Write(_http11OK); + + // Server headers + writer.Write(_headerServer); + + // Date header + writer.Write(DateHeader.HeaderBytes); + + // Content-Length 0 + writer.Write(_headerContentLengthZero); + + // End of headers + writer.Write(_crlf); + writer.Commit(); + } + } +} diff --git a/src/Servers/Kestrel/perf/PlatformBenchmarks/BenchmarkConfigurationHelpers.cs b/src/Servers/Kestrel/perf/PlatformBenchmarks/BenchmarkConfigurationHelpers.cs new file mode 100644 index 0000000000..c6d2cdbdf9 --- /dev/null +++ b/src/Servers/Kestrel/perf/PlatformBenchmarks/BenchmarkConfigurationHelpers.cs @@ -0,0 +1,81 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using Microsoft.AspNetCore; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.Extensions.Configuration; + +namespace PlatformBenchmarks +{ + public static class BenchmarkConfigurationHelpers + { + public static IWebHostBuilder UseBenchmarksConfiguration(this IWebHostBuilder builder, IConfiguration configuration) + { + builder.UseConfiguration(configuration); + + // Handle the transport type + var webHost = builder.GetSetting("KestrelTransport"); + + // Handle the thread count + var threadCountRaw = builder.GetSetting("threadCount"); + int? theadCount = null; + + if (!string.IsNullOrEmpty(threadCountRaw) && + Int32.TryParse(threadCountRaw, out var value)) + { + theadCount = value; + } + + if (string.Equals(webHost, "Libuv", StringComparison.OrdinalIgnoreCase)) + { + builder.UseLibuv(options => + { + if (theadCount.HasValue) + { + options.ThreadCount = theadCount.Value; + } + }); + } + else if (string.Equals(webHost, "Sockets", StringComparison.OrdinalIgnoreCase)) + { + builder.UseSockets(options => + { + if (theadCount.HasValue) + { + options.IOQueueCount = theadCount.Value; + } + }); + } + + return builder; + } + + public static IPEndPoint CreateIPEndPoint(this IConfiguration config) + { + var url = config["server.urls"] ?? config["urls"]; + + if (string.IsNullOrEmpty(url)) + { + return new IPEndPoint(IPAddress.Loopback, 8080); + } + + var address = ServerAddress.FromUrl(url); + + IPAddress ip; + + if (string.Equals(address.Host, "localhost", StringComparison.OrdinalIgnoreCase)) + { + ip = IPAddress.Loopback; + } + else if (!IPAddress.TryParse(address.Host, out ip)) + { + ip = IPAddress.IPv6Any; + } + + return new IPEndPoint(ip, address.Port); + } + } +} diff --git a/src/Servers/Kestrel/perf/PlatformBenchmarks/DateHeader.cs b/src/Servers/Kestrel/perf/PlatformBenchmarks/DateHeader.cs new file mode 100644 index 0000000000..5cfbc5c3b2 --- /dev/null +++ b/src/Servers/Kestrel/perf/PlatformBenchmarks/DateHeader.cs @@ -0,0 +1,59 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers.Text; +using System.Diagnostics; +using System.Text; +using System.Threading; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + /// + /// Manages the generation of the date header value. + /// + internal static class DateHeader + { + const int prefixLength = 8; // "\r\nDate: ".Length + const int dateTimeRLength = 29; // Wed, 14 Mar 2018 14:20:00 GMT + const int suffixLength = 2; // crlf + const int suffixIndex = dateTimeRLength + prefixLength; + + private static readonly Timer s_timer = new Timer((s) => { + SetDateValues(DateTimeOffset.UtcNow); + }, null, 1000, 1000); + + private static byte[] s_headerBytesMaster = new byte[prefixLength + dateTimeRLength + suffixLength]; + private static byte[] s_headerBytesScratch = new byte[prefixLength + dateTimeRLength + suffixLength]; + + static DateHeader() + { + var utf8 = Encoding.ASCII.GetBytes("\r\nDate: ").AsSpan(); + utf8.CopyTo(s_headerBytesMaster); + utf8.CopyTo(s_headerBytesScratch); + s_headerBytesMaster[suffixIndex] = (byte)'\r'; + s_headerBytesMaster[suffixIndex + 1] = (byte)'\n'; + s_headerBytesScratch[suffixIndex] = (byte)'\r'; + s_headerBytesScratch[suffixIndex + 1] = (byte)'\n'; + SetDateValues(DateTimeOffset.UtcNow); + } + + public static ReadOnlySpan HeaderBytes => s_headerBytesMaster; + + private static void SetDateValues(DateTimeOffset value) + { + lock (s_headerBytesScratch) + { + if (!Utf8Formatter.TryFormat(value, s_headerBytesScratch.AsSpan(prefixLength), out int written, 'R')) + { + throw new Exception("date time format failed"); + } + Debug.Assert(written == dateTimeRLength); + var temp = s_headerBytesMaster; + s_headerBytesMaster = s_headerBytesScratch; + s_headerBytesScratch = temp; + } + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/perf/PlatformBenchmarks/HttpApplication.cs b/src/Servers/Kestrel/perf/PlatformBenchmarks/HttpApplication.cs new file mode 100644 index 0000000000..698e8952b0 --- /dev/null +++ b/src/Servers/Kestrel/perf/PlatformBenchmarks/HttpApplication.cs @@ -0,0 +1,161 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace PlatformBenchmarks +{ + public static class HttpApplicationConnectionBuilderExtensions + { + public static IConnectionBuilder UseHttpApplication(this IConnectionBuilder builder) where TConnection : HttpConnection, new() + { + return builder.Use(next => new HttpApplication().ExecuteAsync); + } + } + + public class HttpApplication where TConnection : HttpConnection, new() + { + public Task ExecuteAsync(ConnectionContext connection) + { + var parser = new HttpParser(); + + var httpConnection = new TConnection + { + Parser = parser, + Reader = connection.Transport.Input, + Writer = connection.Transport.Output + }; + return httpConnection.ExecuteAsync(); + } + } + + public class HttpConnection : IHttpHeadersHandler, IHttpRequestLineHandler + { + private State _state; + + public PipeReader Reader { get; set; } + public PipeWriter Writer { get; set; } + + internal HttpParser Parser { get; set; } + + public virtual void OnHeader(Span name, Span value) + { + + } + + public virtual void OnStartLine(HttpMethod method, HttpVersion version, Span target, Span path, Span query, Span customMethod, bool pathEncoded) + { + + } + + public virtual ValueTask ProcessRequestAsync() + { + return default; + } + + public virtual ValueTask OnReadCompletedAsync() + { + return default; + } + + public async Task ExecuteAsync() + { + try + { + await ProcessRequestsAsync(); + + Reader.Complete(); + } + catch (Exception ex) + { + Reader.Complete(ex); + } + finally + { + Writer.Complete(); + } + } + + private async Task ProcessRequestsAsync() + { + while (true) + { + var task = Reader.ReadAsync(); + + if (!task.IsCompleted) + { + // No more data in the input + await OnReadCompletedAsync(); + } + + var result = await task; + var buffer = result.Buffer; + var consumed = buffer.Start; + var examined = buffer.End; + + if (!buffer.IsEmpty) + { + ParseHttpRequest(buffer, out consumed, out examined); + + if (_state != State.Body && result.IsCompleted) + { + ThrowUnexpectedEndOfData(); + } + } + + Reader.AdvanceTo(consumed, examined); + + if (_state == State.Body) + { + await ProcessRequestAsync(); + + _state = State.StartLine; + } + else if (result.IsCompleted) + { + break; + } + } + } + + private void ParseHttpRequest(in ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) + { + consumed = buffer.Start; + examined = buffer.End; + + var parsingStartLine = _state == State.StartLine; + if (parsingStartLine) + { + if (Parser.ParseRequestLine(this, buffer, out consumed, out examined)) + { + _state = State.Headers; + } + } + + if (_state == State.Headers) + { + if (Parser.ParseHeaders(this, parsingStartLine ? buffer.Slice(consumed) : buffer, out consumed, out examined, out int consumedBytes)) + { + _state = State.Body; + } + } + } + + private static void ThrowUnexpectedEndOfData() + { + throw new InvalidOperationException("Unexpected end of data!"); + } + + private enum State + { + StartLine, + Headers, + Body + } + } +} diff --git a/src/Servers/Kestrel/perf/PlatformBenchmarks/PlatformBenchmarks.csproj b/src/Servers/Kestrel/perf/PlatformBenchmarks/PlatformBenchmarks.csproj new file mode 100644 index 0000000000..faeef6ecb2 --- /dev/null +++ b/src/Servers/Kestrel/perf/PlatformBenchmarks/PlatformBenchmarks.csproj @@ -0,0 +1,26 @@ + + + + netcoreapp2.1 + Exe + latest + true + true + + + + + + + + + + + + + + + + + + diff --git a/src/Servers/Kestrel/perf/PlatformBenchmarks/Program.cs b/src/Servers/Kestrel/perf/PlatformBenchmarks/Program.cs new file mode 100644 index 0000000000..7776cbd1ce --- /dev/null +++ b/src/Servers/Kestrel/perf/PlatformBenchmarks/Program.cs @@ -0,0 +1,46 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using Microsoft.AspNetCore; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; + +namespace PlatformBenchmarks +{ + public class Program + { + public static void Main(string[] args) + { + BuildWebHost(args).Run(); + } + + public static IWebHost BuildWebHost(string[] args) + { + var config = new ConfigurationBuilder() + .AddEnvironmentVariables(prefix: "ASPNETCORE_") + .AddCommandLine(args) + .Build(); + + var host = new WebHostBuilder() + .UseBenchmarksConfiguration(config) + .UseKestrel((context, options) => + { + IPEndPoint endPoint = context.Configuration.CreateIPEndPoint(); + + options.Listen(endPoint, builder => + { + builder.UseHttpApplication(); + }); + }) + .UseStartup() + .Build(); + + return host; + } + } +} diff --git a/src/Servers/Kestrel/perf/PlatformBenchmarks/Startup.cs b/src/Servers/Kestrel/perf/PlatformBenchmarks/Startup.cs new file mode 100644 index 0000000000..bfc1f70375 --- /dev/null +++ b/src/Servers/Kestrel/perf/PlatformBenchmarks/Startup.cs @@ -0,0 +1,19 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net.WebSockets; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; + +namespace PlatformBenchmarks +{ + public class Startup + { + public void Configure(IApplicationBuilder app) + { + + } + } +} diff --git a/src/Servers/Kestrel/perf/PlatformBenchmarks/benchmarks.json.json b/src/Servers/Kestrel/perf/PlatformBenchmarks/benchmarks.json.json new file mode 100644 index 0000000000..157cee28dc --- /dev/null +++ b/src/Servers/Kestrel/perf/PlatformBenchmarks/benchmarks.json.json @@ -0,0 +1,15 @@ +{ + "Default": { + "Client": "Wrk", + "PresetHeaders": "Json", + + "Source": { + "Repository": "https://github.com/aspnet/KestrelHttpServer.git", + "BranchOrCommit": "dev", + "Project": "benchmarkapps/PlatformBenchmarks/PlatformBenchmarks.csproj" + } + }, + "JsonPlatform": { + "Path": "/json" + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/perf/PlatformBenchmarks/benchmarks.plaintext.json b/src/Servers/Kestrel/perf/PlatformBenchmarks/benchmarks.plaintext.json new file mode 100644 index 0000000000..5ad58aee15 --- /dev/null +++ b/src/Servers/Kestrel/perf/PlatformBenchmarks/benchmarks.plaintext.json @@ -0,0 +1,26 @@ +{ + "Default": { + "Client": "Wrk", + "PresetHeaders": "Plaintext", + "ClientProperties": { + "ScriptName": "pipeline", + "PipelineDepth": 16 + }, + "Source": { + "Repository": "https://github.com/aspnet/KestrelHttpServer.git", + "BranchOrCommit": "dev", + "Project": "benchmarkapps/PlatformBenchmarks/PlatformBenchmarks.csproj" + }, + "Port": 8080 + }, + "PlaintextPlatform": { + "Path": "/plaintext" + }, + "PlaintextNonPipelinedPlatform": { + "Path": "/plaintext", + "ClientProperties": { + "ScriptName": "", + "PipelineDepth": 0 + } + } +} diff --git a/src/Servers/Kestrel/samples/Http2SampleApp/Dockerfile b/src/Servers/Kestrel/samples/Http2SampleApp/Dockerfile new file mode 100644 index 0000000000..e93d563bde --- /dev/null +++ b/src/Servers/Kestrel/samples/Http2SampleApp/Dockerfile @@ -0,0 +1,14 @@ +FROM microsoft/aspnetcore:2.0.0-stretch + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + libssl-dev && \ + rm -rf /var/lib/apt/lists/* + +ARG CONFIGURATION=Debug + +WORKDIR /app + +COPY ./bin/${CONFIGURATION}/netcoreapp2.0/publish/ /app + +ENTRYPOINT [ "/usr/bin/dotnet", "/app/Http2SampleApp.dll" ] diff --git a/src/Servers/Kestrel/samples/Http2SampleApp/Http2SampleApp.csproj b/src/Servers/Kestrel/samples/Http2SampleApp/Http2SampleApp.csproj new file mode 100644 index 0000000000..08078fe3cf --- /dev/null +++ b/src/Servers/Kestrel/samples/Http2SampleApp/Http2SampleApp.csproj @@ -0,0 +1,20 @@ + + + + netcoreapp2.1 + false + true + + + + + + + + + + PreserveNewest + + + + diff --git a/src/Servers/Kestrel/samples/Http2SampleApp/Program.cs b/src/Servers/Kestrel/samples/Http2SampleApp/Program.cs new file mode 100644 index 0000000000..6e023ec03d --- /dev/null +++ b/src/Servers/Kestrel/samples/Http2SampleApp/Program.cs @@ -0,0 +1,44 @@ +using System; +using System.IO; +using System.Net; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; + +namespace Http2SampleApp +{ + public class Program + { + public static void Main(string[] args) + { + var hostBuilder = new WebHostBuilder() + .ConfigureLogging((_, factory) => + { + // Set logging to the MAX. + factory.SetMinimumLevel(LogLevel.Trace); + factory.AddConsole(); + }) + .UseKestrel((context, options) => + { + var basePort = context.Configuration.GetValue("BASE_PORT") ?? 5000; + + // Run callbacks on the transport thread + options.ApplicationSchedulingMode = SchedulingMode.Inline; + + options.Listen(IPAddress.Any, basePort, listenOptions => + { + // This only works becuase InternalsVisibleTo is enabled for this sample. + listenOptions.Protocols = HttpProtocols.Http1AndHttp2; + listenOptions.UseHttps("testCert.pfx", "testPassword"); + listenOptions.UseConnectionLogging(); + }); + }) + .UseContentRoot(Directory.GetCurrentDirectory()) + .UseStartup(); + + hostBuilder.Build().Run(); + } + } +} diff --git a/src/Servers/Kestrel/samples/Http2SampleApp/Startup.cs b/src/Servers/Kestrel/samples/Http2SampleApp/Startup.cs new file mode 100644 index 0000000000..904e07cbb8 --- /dev/null +++ b/src/Servers/Kestrel/samples/Http2SampleApp/Startup.cs @@ -0,0 +1,23 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; + +namespace Http2SampleApp +{ + public class Startup + { + + public void ConfigureServices(IServiceCollection services) + { + } + + public void Configure(IApplicationBuilder app, IHostingEnvironment env) + { + app.Run(context => + { + return context.Response.WriteAsync("Hello World! " + context.Request.Protocol); + }); + } + } +} diff --git a/src/Servers/Kestrel/samples/Http2SampleApp/scripts/build-docker.ps1 b/src/Servers/Kestrel/samples/Http2SampleApp/scripts/build-docker.ps1 new file mode 100644 index 0000000000..eda82ace6f --- /dev/null +++ b/src/Servers/Kestrel/samples/Http2SampleApp/scripts/build-docker.ps1 @@ -0,0 +1,3 @@ +dotnet publish --framework netcoreapp2.0 "$PSScriptRoot/../Http2SampleApp.csproj" + +docker build -t kestrel-http2-sample (Convert-Path "$PSScriptRoot/..") diff --git a/src/Servers/Kestrel/samples/Http2SampleApp/scripts/build-docker.sh b/src/Servers/Kestrel/samples/Http2SampleApp/scripts/build-docker.sh new file mode 100644 index 0000000000..ca226f0b53 --- /dev/null +++ b/src/Servers/Kestrel/samples/Http2SampleApp/scripts/build-docker.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +dotnet publish --framework netcoreapp2.0 "$DIR/../Http2SampleApp.csproj" + +docker build -t kestrel-http2-sample "$DIR/.." diff --git a/src/Servers/Kestrel/samples/Http2SampleApp/scripts/run-docker.ps1 b/src/Servers/Kestrel/samples/Http2SampleApp/scripts/run-docker.ps1 new file mode 100644 index 0000000000..7b371b6dde --- /dev/null +++ b/src/Servers/Kestrel/samples/Http2SampleApp/scripts/run-docker.ps1 @@ -0,0 +1 @@ +docker run -p 5000:5000 -it --rm kestrel-http2-sample diff --git a/src/Servers/Kestrel/samples/Http2SampleApp/scripts/run-docker.sh b/src/Servers/Kestrel/samples/Http2SampleApp/scripts/run-docker.sh new file mode 100644 index 0000000000..3039b34a98 --- /dev/null +++ b/src/Servers/Kestrel/samples/Http2SampleApp/scripts/run-docker.sh @@ -0,0 +1,2 @@ +#!/usr/bin/env bash +docker run -it -p 5000:5000 --rm kestrel-http2-sample diff --git a/src/Servers/Kestrel/samples/Http2SampleApp/testCert.pfx b/src/Servers/Kestrel/samples/Http2SampleApp/testCert.pfx new file mode 100644 index 0000000000..7118908c2d Binary files /dev/null and b/src/Servers/Kestrel/samples/Http2SampleApp/testCert.pfx differ diff --git a/src/Servers/Kestrel/samples/LargeResponseApp/LargeResponseApp.csproj b/src/Servers/Kestrel/samples/LargeResponseApp/LargeResponseApp.csproj new file mode 100644 index 0000000000..16d7aa18d1 --- /dev/null +++ b/src/Servers/Kestrel/samples/LargeResponseApp/LargeResponseApp.csproj @@ -0,0 +1,13 @@ + + + + netcoreapp2.0;net461 + false + true + + + + + + + diff --git a/src/Servers/Kestrel/samples/LargeResponseApp/Startup.cs b/src/Servers/Kestrel/samples/LargeResponseApp/Startup.cs new file mode 100644 index 0000000000..c2b7d30a1b --- /dev/null +++ b/src/Servers/Kestrel/samples/LargeResponseApp/Startup.cs @@ -0,0 +1,54 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; +using System.Net; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; + +namespace LargeResponseApp +{ + public class Startup + { + private const int _chunkSize = 4096; + private const int _defaultNumChunks = 16; + private static byte[] _chunk = Encoding.UTF8.GetBytes(new string('a', _chunkSize)); + + public void Configure(IApplicationBuilder app) + { + app.Run(async (context) => + { + int numChunks; + var path = context.Request.Path; + if (!path.HasValue || !int.TryParse(path.Value.Substring(1), out numChunks)) + { + numChunks = _defaultNumChunks; + } + + context.Response.ContentLength = _chunkSize * numChunks; + context.Response.ContentType = "text/plain"; + + for (int i = 0; i < numChunks; i++) + { + await context.Response.Body.WriteAsync(_chunk, 0, _chunkSize).ConfigureAwait(false); + } + }); + } + + public static Task Main(string[] args) + { + var host = new WebHostBuilder() + .UseKestrel(options => + { + options.Listen(IPAddress.Loopback, 5001); + }) + .UseContentRoot(Directory.GetCurrentDirectory()) + .UseStartup() + .Build(); + + return host.RunAsync(); + } + } +} diff --git a/src/Servers/Kestrel/samples/PlaintextApp/PlaintextApp.csproj b/src/Servers/Kestrel/samples/PlaintextApp/PlaintextApp.csproj new file mode 100644 index 0000000000..f2c0c8a820 --- /dev/null +++ b/src/Servers/Kestrel/samples/PlaintextApp/PlaintextApp.csproj @@ -0,0 +1,13 @@ + + + + netcoreapp2.1;net461 + false + true + + + + + + + diff --git a/src/Servers/Kestrel/samples/PlaintextApp/Startup.cs b/src/Servers/Kestrel/samples/PlaintextApp/Startup.cs new file mode 100644 index 0000000000..28da0a6f2c --- /dev/null +++ b/src/Servers/Kestrel/samples/PlaintextApp/Startup.cs @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; +using System.Net; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; + +namespace PlaintextApp +{ + public class Startup + { + private static readonly byte[] _helloWorldBytes = Encoding.UTF8.GetBytes("Hello, World!"); + + public void Configure(IApplicationBuilder app) + { + app.Run((httpContext) => + { + var response = httpContext.Response; + response.StatusCode = 200; + response.ContentType = "text/plain"; + + var helloWorld = _helloWorldBytes; + response.ContentLength = helloWorld.Length; + return response.Body.WriteAsync(helloWorld, 0, helloWorld.Length); + }); + } + + public static Task Main(string[] args) + { + var host = new WebHostBuilder() + .UseKestrel(options => + { + options.Listen(IPAddress.Loopback, 5001); + }) + .UseContentRoot(Directory.GetCurrentDirectory()) + .UseStartup() + .Build(); + + return host.RunAsync(); + } + } +} diff --git a/src/Servers/Kestrel/samples/SampleApp/SampleApp.csproj b/src/Servers/Kestrel/samples/SampleApp/SampleApp.csproj new file mode 100644 index 0000000000..fff5a6c3bd --- /dev/null +++ b/src/Servers/Kestrel/samples/SampleApp/SampleApp.csproj @@ -0,0 +1,35 @@ + + + + netcoreapp2.1;netcoreapp2.0;net461 + false + true + + + + + + + + + + + + + PreserveNewest + + + + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + + diff --git a/src/Servers/Kestrel/samples/SampleApp/Startup.cs b/src/Servers/Kestrel/samples/SampleApp/Startup.cs new file mode 100644 index 0000000000..bb070eb348 --- /dev/null +++ b/src/Servers/Kestrel/samples/SampleApp/Startup.cs @@ -0,0 +1,169 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.IO; +using System.Net; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; + +namespace SampleApp +{ + public class Startup + { + public void Configure(IApplicationBuilder app, ILoggerFactory loggerFactory) + { + var logger = loggerFactory.CreateLogger("Default"); + + app.Run(async context => + { + var connectionFeature = context.Connection; + logger.LogDebug($"Peer: {connectionFeature.RemoteIpAddress?.ToString()}:{connectionFeature.RemotePort}" + + $"{Environment.NewLine}" + + $"Sock: {connectionFeature.LocalIpAddress?.ToString()}:{connectionFeature.LocalPort}"); + + var response = $"hello, world{Environment.NewLine}"; + context.Response.ContentLength = response.Length; + context.Response.ContentType = "text/plain"; + await context.Response.WriteAsync(response); + }); + } + + public static Task Main(string[] args) + { + TaskScheduler.UnobservedTaskException += (sender, e) => + { + Console.WriteLine("Unobserved exception: {0}", e.Exception); + }; + + var hostBuilder = new WebHostBuilder() + .ConfigureLogging((_, factory) => + { + factory.AddConsole(); + }) + .ConfigureAppConfiguration((hostingContext, config) => + { + var env = hostingContext.HostingEnvironment; + config.AddJsonFile("appsettings.json", optional: true) + .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true); + }) + .UseKestrel((context, options) => + { + if (context.HostingEnvironment.IsDevelopment()) + { + ShowConfig(context.Configuration); + } + + var basePort = context.Configuration.GetValue("BASE_PORT") ?? 5000; + + options.ConfigureEndpointDefaults(opt => + { + opt.NoDelay = true; + }); + + options.ConfigureHttpsDefaults(httpsOptions => + { + httpsOptions.SslProtocols = SslProtocols.Tls12; + }); + + // Run callbacks on the transport thread + options.ApplicationSchedulingMode = SchedulingMode.Inline; + + options.Listen(IPAddress.Loopback, basePort, listenOptions => + { + // Uncomment the following to enable Nagle's algorithm for this endpoint. + //listenOptions.NoDelay = false; + + listenOptions.UseConnectionLogging(); + }); + + options.Listen(IPAddress.Loopback, basePort + 1, listenOptions => + { + listenOptions.UseHttps("testCert.pfx", "testPassword"); + listenOptions.UseConnectionLogging(); + }); + + options.ListenLocalhost(basePort + 2, listenOptions => + { + // Use default dev cert + listenOptions.UseHttps(); + }); + + options.ListenAnyIP(basePort + 3); + + options.ListenAnyIP(basePort + 4, listenOptions => + { + listenOptions.UseHttps(StoreName.My, "localhost", allowInvalid: true); + }); + + options.ListenAnyIP(basePort + 5, listenOptions => + { + listenOptions.UseHttps(httpsOptions => + { + var localhostCert = CertificateLoader.LoadFromStoreCert("localhost", "My", StoreLocation.CurrentUser, allowInvalid: true); + httpsOptions.ServerCertificateSelector = (features, name) => + { + // Here you would check the name, select an appropriate cert, and provide a fallback or fail for null names. + return localhostCert; + }; + }); + }); + + options + .Configure() + .Endpoint(IPAddress.Loopback, basePort + 6) + .LocalhostEndpoint(basePort + 7) + .Load(); + + options + .Configure(context.Configuration.GetSection("Kestrel")) + .Endpoint("NamedEndpoint", opt => + { + opt.ListenOptions.NoDelay = true; + }) + .Endpoint("NamedHttpsEndpoint", opt => + { + opt.HttpsOptions.SslProtocols = SslProtocols.Tls12; + }); + + options.UseSystemd(); + + // The following section should be used to demo sockets + //options.ListenUnixSocket("/tmp/kestrel-test.sock"); + }) + .UseContentRoot(Directory.GetCurrentDirectory()) + .UseStartup(); + + if (string.Equals(Process.GetCurrentProcess().Id.ToString(), Environment.GetEnvironmentVariable("LISTEN_PID"))) + { + // Use libuv if activated by systemd, since that's currently the only transport that supports being passed a socket handle. + hostBuilder.UseLibuv(options => + { + // Uncomment the following line to change the default number of libuv threads for all endpoints. + // options.ThreadCount = 4; + }); + } + + return hostBuilder.Build().RunAsync(); + } + + private static void ShowConfig(IConfiguration config) + { + foreach (var pair in config.GetChildren()) + { + Console.WriteLine($"{pair.Path} - {pair.Value}"); + ShowConfig(pair); + } + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/samples/SampleApp/appsettings.Development.json b/src/Servers/Kestrel/samples/SampleApp/appsettings.Development.json new file mode 100644 index 0000000000..741bd03aee --- /dev/null +++ b/src/Servers/Kestrel/samples/SampleApp/appsettings.Development.json @@ -0,0 +1,15 @@ +{ + "Kestrel": { + "Endpoints": { + "NamedEndpoint": { "Url": "http://localhost:6000" }, + "NamedHttpsEndpoint": { + "Url": "https://localhost:6443", + "Certificate": { + "Subject": "localhost", + "Store": "My", + "AllowInvalid": true + } + } + } + } +} diff --git a/src/Servers/Kestrel/samples/SampleApp/appsettings.Production.json b/src/Servers/Kestrel/samples/SampleApp/appsettings.Production.json new file mode 100644 index 0000000000..8719fb89b7 --- /dev/null +++ b/src/Servers/Kestrel/samples/SampleApp/appsettings.Production.json @@ -0,0 +1,14 @@ +{ + "Kestrel": { + "Endpoints": { + "NamedEndpoint": { "Url": "http://*:6000" }, + "NamedHttpsEndpoint": { + "Url": "https://*:6443", + "Certificate": { + "Path": "testCert.pfx", + "Password": "testPassword" + } + } + } + } +} diff --git a/src/Servers/Kestrel/samples/SampleApp/appsettings.json b/src/Servers/Kestrel/samples/SampleApp/appsettings.json new file mode 100644 index 0000000000..cd77ddd218 --- /dev/null +++ b/src/Servers/Kestrel/samples/SampleApp/appsettings.json @@ -0,0 +1,6 @@ +{ + "Kestrel": { + "Endpoints": { + } + } +} diff --git a/src/Servers/Kestrel/samples/SampleApp/testCert.pfx b/src/Servers/Kestrel/samples/SampleApp/testCert.pfx new file mode 100644 index 0000000000..7118908c2d Binary files /dev/null and b/src/Servers/Kestrel/samples/SampleApp/testCert.pfx differ diff --git a/src/Servers/Kestrel/samples/SystemdTestApp/Startup.cs b/src/Servers/Kestrel/samples/SystemdTestApp/Startup.cs new file mode 100644 index 0000000000..0b3c5e05de --- /dev/null +++ b/src/Servers/Kestrel/samples/SystemdTestApp/Startup.cs @@ -0,0 +1,92 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.IO; +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; + +namespace SystemdTestApp +{ + public class Startup + { + public void Configure(IApplicationBuilder app, ILoggerFactory loggerFactory) + { + var logger = loggerFactory.CreateLogger("Default"); + + app.Run(async context => + { + var connectionFeature = context.Connection; + logger.LogDebug($"Peer: {connectionFeature.RemoteIpAddress?.ToString()}:{connectionFeature.RemotePort}" + + $"{Environment.NewLine}" + + $"Sock: {connectionFeature.LocalIpAddress?.ToString()}:{connectionFeature.LocalPort}"); + + var response = $"hello, world{Environment.NewLine}"; + context.Response.ContentLength = response.Length; + context.Response.ContentType = "text/plain"; + await context.Response.WriteAsync(response); + }); + } + + public static Task Main(string[] args) + { + TaskScheduler.UnobservedTaskException += (sender, e) => + { + Console.WriteLine("Unobserved exception: {0}", e.Exception); + }; + + var hostBuilder = new WebHostBuilder() + .ConfigureLogging((_, factory) => + { + factory.AddConsole(); + }) + .UseKestrel((context, options) => + { + var basePort = context.Configuration.GetValue("BASE_PORT") ?? 5000; + + // Run callbacks on the transport thread + options.ApplicationSchedulingMode = SchedulingMode.Inline; + + options.Listen(IPAddress.Loopback, basePort, listenOptions => + { + // Uncomment the following to enable Nagle's algorithm for this endpoint. + //listenOptions.NoDelay = false; + + listenOptions.UseConnectionLogging(); + }); + + options.Listen(IPAddress.Loopback, basePort + 1, listenOptions => + { + listenOptions.UseHttps("testCert.pfx", "testPassword"); + listenOptions.UseConnectionLogging(); + }); + + options.UseSystemd(); + + // The following section should be used to demo sockets + //options.ListenUnixSocket("/tmp/kestrel-test.sock"); + }) + .UseContentRoot(Directory.GetCurrentDirectory()) + .UseStartup(); + + if (string.Equals(Process.GetCurrentProcess().Id.ToString(), Environment.GetEnvironmentVariable("LISTEN_PID"))) + { + // Use libuv if activated by systemd, since that's currently the only transport that supports being passed a socket handle. + hostBuilder.UseLibuv(options => + { + // Uncomment the following line to change the default number of libuv threads for all endpoints. + // options.ThreadCount = 4; + }); + } + + return hostBuilder.Build().RunAsync(); + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/samples/SystemdTestApp/SystemdTestApp.csproj b/src/Servers/Kestrel/samples/SystemdTestApp/SystemdTestApp.csproj new file mode 100644 index 0000000000..a9bd0733b2 --- /dev/null +++ b/src/Servers/Kestrel/samples/SystemdTestApp/SystemdTestApp.csproj @@ -0,0 +1,21 @@ + + + + netcoreapp2.1;netcoreapp2.0;net461 + false + true + + + + + + + + + + + PreserveNewest + + + + diff --git a/src/Servers/Kestrel/samples/SystemdTestApp/testCert.pfx b/src/Servers/Kestrel/samples/SystemdTestApp/testCert.pfx new file mode 100644 index 0000000000..7118908c2d Binary files /dev/null and b/src/Servers/Kestrel/samples/SystemdTestApp/testCert.pfx differ diff --git a/src/Servers/Kestrel/shared/src/ThrowHelper.cs b/src/Servers/Kestrel/shared/src/ThrowHelper.cs new file mode 100644 index 0000000000..da31583777 --- /dev/null +++ b/src/Servers/Kestrel/shared/src/ThrowHelper.cs @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics; +using System.Runtime.CompilerServices; + +namespace System.Buffers +{ + internal class ThrowHelper + { + public static void ThrowArgumentOutOfRangeException(int sourceLength, int offset) + { + throw GetArgumentOutOfRangeException(sourceLength, offset); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static ArgumentOutOfRangeException GetArgumentOutOfRangeException(int sourceLength, int offset) + { + if ((uint)offset > (uint)sourceLength) + { + // Offset is negative or less than array length + return new ArgumentOutOfRangeException(GetArgumentName(ExceptionArgument.offset)); + } + + // The third parameter (not passed) length must be out of range + return new ArgumentOutOfRangeException(GetArgumentName(ExceptionArgument.length)); + } + + public static void ThrowArgumentOutOfRangeException(ExceptionArgument argument) + { + throw GetArgumentOutOfRangeException(argument); + } + + public static void ThrowInvalidOperationException_ReferenceCountZero() + { + throw new InvalidOperationException("Can't release when reference count is already zero"); + } + + public static void ThrowInvalidOperationException_ReturningPinnedBlock() + { + throw new InvalidOperationException("Can't release when reference count is already zero"); + } + + public static void ThrowArgumentOutOfRangeException_BufferRequestTooLarge(int maxSize) + { + throw GetArgumentOutOfRangeException_BufferRequestTooLarge(maxSize); + } + + public static void ThrowObjectDisposedException(ExceptionArgument argument) + { + throw GetObjectDisposedException(argument); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static ArgumentOutOfRangeException GetArgumentOutOfRangeException(ExceptionArgument argument) + { + return new ArgumentOutOfRangeException(GetArgumentName(argument)); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static ArgumentOutOfRangeException GetArgumentOutOfRangeException_BufferRequestTooLarge(int maxSize) + { + return new ArgumentOutOfRangeException(GetArgumentName(ExceptionArgument.size), $"Cannot allocate more than {maxSize} bytes in a single buffer"); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static ObjectDisposedException GetObjectDisposedException(ExceptionArgument argument) + { + return new ObjectDisposedException(GetArgumentName(argument)); + } + + private static string GetArgumentName(ExceptionArgument argument) + { + Debug.Assert(Enum.IsDefined(typeof(ExceptionArgument), argument), "The enum value is not defined, please check the ExceptionArgument Enum."); + + return argument.ToString(); + } + } + + internal enum ExceptionArgument + { + size, + offset, + length, + MemoryPoolBlock + } +} diff --git a/src/Servers/Kestrel/shared/test/DisposableStack.cs b/src/Servers/Kestrel/shared/test/DisposableStack.cs new file mode 100644 index 0000000000..325fc6f8c8 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/DisposableStack.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Server.Kestrel.Tests +{ + public class DisposableStack : Stack, IDisposable + where T : IDisposable + { + public void Dispose() + { + while (Count > 0) + { + Pop()?.Dispose(); + } + } + } +} diff --git a/src/Servers/Kestrel/shared/test/DummyApplication.cs b/src/Servers/Kestrel/shared/test/DummyApplication.cs new file mode 100644 index 0000000000..389a799b2d --- /dev/null +++ b/src/Servers/Kestrel/shared/test/DummyApplication.cs @@ -0,0 +1,48 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Testing +{ + public class DummyApplication : IHttpApplication + { + private readonly RequestDelegate _requestDelegate; + private readonly IHttpContextFactory _httpContextFactory; + + public DummyApplication() + : this(_ => Task.CompletedTask) + { + } + + public DummyApplication(RequestDelegate requestDelegate) + : this(requestDelegate, null) + { + } + + public DummyApplication(RequestDelegate requestDelegate, IHttpContextFactory httpContextFactory) + { + _requestDelegate = requestDelegate; + _httpContextFactory = httpContextFactory; + } + + public HttpContext CreateContext(IFeatureCollection contextFeatures) + { + return _httpContextFactory?.Create(contextFeatures) ?? new DefaultHttpContext(contextFeatures); + } + + public void DisposeContext(HttpContext context, Exception exception) + { + _httpContextFactory?.Dispose(context); + } + + public async Task ProcessRequestAsync(HttpContext context) + { + await _requestDelegate(context); + } + } +} diff --git a/src/Servers/Kestrel/shared/test/EventRaisingResourceCounter.cs b/src/Servers/Kestrel/shared/test/EventRaisingResourceCounter.cs new file mode 100644 index 0000000000..8f7d5efea3 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/EventRaisingResourceCounter.cs @@ -0,0 +1,34 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Tests +{ + public class EventRaisingResourceCounter : ResourceCounter + { + private readonly ResourceCounter _wrapped; + + public EventRaisingResourceCounter(ResourceCounter wrapped) + { + _wrapped = wrapped; + } + + public event EventHandler OnRelease; + public event EventHandler OnLock; + + public override void ReleaseOne() + { + _wrapped.ReleaseOne(); + OnRelease?.Invoke(this, EventArgs.Empty); + } + + public override bool TryLockOne() + { + var retVal = _wrapped.TryLockOne(); + OnLock?.Invoke(this, retVal); + return retVal; + } + } +} diff --git a/src/Servers/Kestrel/shared/test/HttpParsingData.cs b/src/Servers/Kestrel/shared/test/HttpParsingData.cs new file mode 100644 index 0000000000..94ac2c8e39 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/HttpParsingData.cs @@ -0,0 +1,487 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Xunit; + +namespace Microsoft.AspNetCore.Testing +{ + public class HttpParsingData + { + public static IEnumerable RequestLineValidData + { + get + { + var methods = new[] + { + "GET", + "CUSTOM", + }; + var paths = new[] + { + Tuple.Create("/", "/"), + Tuple.Create("/abc", "/abc"), + Tuple.Create("/abc/de/f", "/abc/de/f"), + Tuple.Create("/%20", "/ "), + Tuple.Create("/a%20", "/a "), + Tuple.Create("/%20a", "/ a"), + Tuple.Create("/a/b%20c", "/a/b c"), + Tuple.Create("/%C3%A5", "/\u00E5"), + Tuple.Create("/a%C3%A5a", "/a\u00E5a"), + Tuple.Create("/%C3%A5/bc", "/\u00E5/bc"), + Tuple.Create("/%25", "/%"), + Tuple.Create("/%25%30%30", "/%00"), + Tuple.Create("/%%2000", "/% 00"), + Tuple.Create("/%2F", "/%2F"), + Tuple.Create("http://host/abs/path", "/abs/path"), + Tuple.Create("http://host/abs/path/", "/abs/path/"), + Tuple.Create("http://host/a%20b%20c/", "/a b c/"), + Tuple.Create("https://host/abs/path", "/abs/path"), + Tuple.Create("https://host/abs/path/", "/abs/path/"), + Tuple.Create("https://host:22/abs/path", "/abs/path"), + Tuple.Create("https://user@host:9080/abs/path", "/abs/path"), + Tuple.Create("http://host/", "/"), + Tuple.Create("http://host", "/"), + Tuple.Create("https://host/", "/"), + Tuple.Create("https://host", "/"), + Tuple.Create("http://user@host/", "/"), + Tuple.Create("http://127.0.0.1/", "/"), + Tuple.Create("http://user@127.0.0.1/", "/"), + Tuple.Create("http://user@127.0.0.1:8080/", "/"), + Tuple.Create("http://127.0.0.1:8080/", "/"), + Tuple.Create("http://[::1]", "/"), + Tuple.Create("http://[::1]/path", "/path"), + Tuple.Create("http://[::1]:8080/", "/"), + Tuple.Create("http://user@[::1]:8080/", "/"), + }; + var queryStrings = new[] + { + "", + "?", + "?arg1=val1", + "?arg1=a%20b", + "?%A", + "?%20=space", + "?%C3%A5=val", + "?path=/home", + "?path=/%C3%A5/", + "?question=what?", + "?%00", + "?arg=%00" + }; + var httpVersions = new[] + { + "HTTP/1.0", + "HTTP/1.1" + }; + + return from method in methods + from path in paths + from queryString in queryStrings + from httpVersion in httpVersions + select new[] + { + $"{method} {path.Item1}{queryString} {httpVersion}\r\n", + method, + $"{path.Item1}{queryString}", + $"{path.Item1}", + $"{path.Item2}", + queryString, + httpVersion + }; + } + } + + public static IEnumerable RequestLineDotSegmentData => new[] + { + new[] { "GET /a/../b HTTP/1.1\r\n", "/a/../b", "/b", "" }, + new[] { "GET /%61/../%62 HTTP/1.1\r\n", "/%61/../%62", "/b", "" }, + new[] { "GET /a/%2E%2E/b HTTP/1.1\r\n", "/a/%2E%2E/b", "/b", "" }, + new[] { "GET /%61/%2E%2E/%62 HTTP/1.1\r\n", "/%61/%2E%2E/%62", "/b", "" }, + new[] { "GET /a?p=/a/../b HTTP/1.1\r\n", "/a?p=/a/../b", "/a", "?p=/a/../b" }, + new[] { "GET /a?p=/a/%2E%2E/b HTTP/1.1\r\n", "/a?p=/a/%2E%2E/b", "/a", "?p=/a/%2E%2E/b" }, + new[] { "GET http://example.com/a/../b HTTP/1.1\r\n", "http://example.com/a/../b", "/b", "" }, + new[] { "GET http://example.com/%61/../%62 HTTP/1.1\r\n", "http://example.com/%61/../%62", "/b", "" }, + new[] { "GET http://example.com/a/%2E%2E/b HTTP/1.1\r\n", "http://example.com/a/%2E%2E/b", "/b", "" }, + new[] { "GET http://example.com/%61/%2E%2E/%62 HTTP/1.1\r\n", "http://example.com/%61/%2E%2E/%62", "/b", "" }, + new[] { "GET http://example.com/a?p=/a/../b HTTP/1.1\r\n", "http://example.com/a?p=/a/../b", "/a", "?p=/a/../b" }, + new[] { "GET http://example.com/a?p=/a/%2E%2E/b HTTP/1.1\r\n", "http://example.com/a?p=/a/%2E%2E/b", "/a", "?p=/a/%2E%2E/b" }, + new[] { "GET http://example.com?p=/a/../b HTTP/1.1\r\n", "http://example.com?p=/a/../b", "/", "?p=/a/../b" }, + new[] { "GET http://example.com?p=/a/%2E%2E/b HTTP/1.1\r\n", "http://example.com?p=/a/%2E%2E/b", "/", "?p=/a/%2E%2E/b" }, + + // Asterisk-form and authority-form should be unaffected and cause no issues + new[] { "OPTIONS * HTTP/1.1\r\n", "*", "", "" }, + new[] { "CONNECT www.example.com HTTP/1.1\r\n", "www.example.com", "", "" }, + }; + + public static IEnumerable RequestLineIncompleteData => new[] + { + "G", + "GE", + "GET", + "GET ", + "GET /", + "GET / ", + "GET / H", + "GET / HT", + "GET / HTT", + "GET / HTTP", + "GET / HTTP/", + "GET / HTTP/1", + "GET / HTTP/1.", + "GET / HTTP/1.1", + "GET / HTTP/1.1\r", + }; + + public static IEnumerable RequestLineInvalidData + { + get + { + return new[] + { + "G\r\n", + "GE\r\n", + "GET\r\n", + "GET \r\n", + "GET /\r\n", + "GET / \r\n", + "GET/HTTP/1.1\r\n", + "GET /HTTP/1.1\r\n", + " \r\n", + " \r\n", + "/ HTTP/1.1\r\n", + " / HTTP/1.1\r\n", + "/ \r\n", + "GET \r\n", + "GET HTTP/1.0\r\n", + "GET HTTP/1.1\r\n", + "GET / \n", + "GET / HTTP/1.0\n", + "GET / HTTP/1.1\n", + "GET / HTTP/1.0\rA\n", + "GET / HTTP/1.1\ra\n", + "GET? / HTTP/1.1\r\n", + "GET ? HTTP/1.1\r\n", + "GET /a?b=cHTTP/1.1\r\n", + "GET /a%20bHTTP/1.1\r\n", + "GET /a%20b?c=dHTTP/1.1\r\n", + "GET %2F HTTP/1.1\r\n", + "GET %00 HTTP/1.1\r\n", + "CUSTOM \r\n", + "CUSTOM /\r\n", + "CUSTOM / \r\n", + "CUSTOM /HTTP/1.1\r\n", + "CUSTOM \r\n", + "CUSTOM HTTP/1.0\r\n", + "CUSTOM HTTP/1.1\r\n", + "CUSTOM / \n", + "CUSTOM / HTTP/1.0\n", + "CUSTOM / HTTP/1.1\n", + "CUSTOM / HTTP/1.0\rA\n", + "CUSTOM / HTTP/1.1\ra\n", + "CUSTOM ? HTTP/1.1\r\n", + "CUSTOM /a?b=cHTTP/1.1\r\n", + "CUSTOM /a%20bHTTP/1.1\r\n", + "CUSTOM /a%20b?c=dHTTP/1.1\r\n", + "CUSTOM %2F HTTP/1.1\r\n", + "CUSTOM %00 HTTP/1.1\r\n", + }.Concat(MethodWithNonTokenCharData.Select(method => $"{method} / HTTP/1.0\r\n")); + } + } + + // Bad HTTP Methods (invalid according to RFC) + public static IEnumerable MethodWithNonTokenCharData + { + get + { + return new[] + { + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + "\"", + "/", + "[", + "]", + "?", + "=", + "{", + "}", + "get@", + "post=", + "[0x00]" + }.Concat(MethodWithNullCharData); + } + } + + public static IEnumerable MethodWithNullCharData => new[] + { + // Bad HTTP Methods (invalid according to RFC) + "\0", + "\0GET", + "G\0T", + "GET\0", + }; + + public static IEnumerable TargetWithEncodedNullCharData => new[] + { + "/%00", + "/%00%00", + "/%E8%00%84", + "/%E8%85%00", + "/%F3%00%82%86", + "/%F3%85%00%82", + "/%F3%85%82%00", + }; + + public static TheoryData TargetInvalidData + { + get + { + var data = new TheoryData(); + + // Invalid absolute-form + data.Add("GET", "http://"); + data.Add("GET", "http:/"); + data.Add("GET", "https:/"); + data.Add("GET", "http:///"); + data.Add("GET", "https://"); + data.Add("GET", "http:////"); + data.Add("GET", "http://:80"); + data.Add("GET", "http://:80/abc"); + data.Add("GET", "http://user@"); + data.Add("GET", "http://user@/abc"); + data.Add("GET", "http://abc%20xyz/abc"); + data.Add("GET", "http://%20/abc?query=%0A"); + // Valid absolute-form but with unsupported schemes + data.Add("GET", "otherscheme://host/"); + data.Add("GET", "ws://host/"); + data.Add("GET", "wss://host/"); + // Must only have one asterisk + data.Add("OPTIONS", "**"); + // Relative form + data.Add("GET", "../../"); + data.Add("GET", "..\\."); + + return data; + } + } + + public static TheoryData MethodNotAllowedRequestLine + { + get + { + var methods = new[] + { + "GET", + "PUT", + "DELETE", + "POST", + "HEAD", + "TRACE", + "PATCH", + "CONNECT", + "OPTIONS", + "CUSTOM", + }; + + var data = new TheoryData(); + + foreach (var method in methods.Except(new[] { "OPTIONS" })) + { + data.Add($"{method} * HTTP/1.1\r\n", HttpMethod.Options); + } + + foreach (var method in methods.Except(new[] { "CONNECT" })) + { + data.Add($"{method} www.example.com:80 HTTP/1.1\r\n", HttpMethod.Connect); + } + + return data; + } + } + + public static IEnumerable TargetWithNullCharData + { + get + { + return new[] + { + "\0", + "/\0", + "/\0\0", + "/%C8\0", + }.Concat(QueryStringWithNullCharData); + } + } + + public static IEnumerable QueryStringWithNullCharData => new[] + { + "/?\0=a", + "/?a=\0", + }; + + public static TheoryData UnrecognizedHttpVersionData => new TheoryData + { + " ", + "/", + "H", + "HT", + "HTT", + "HTTP", + "HTTP/", + "HTTP/1", + "HTTP/1.", + "http/1.0", + "http/1.1", + "HTTP/1.1 ", + "HTTP/1.0a", + "HTTP/1.0ab", + "HTTP/1.1a", + "HTTP/1.1ab", + "HTTP/1.2", + "HTTP/3.0", + "hello", + "8charact", + }; + + public static IEnumerable RequestHeaderInvalidData => new[] + { + // Missing CR + new[] { "Header: value\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header: value\x0A") }, + new[] { "Header-1: value1\nHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: value1\x0A") }, + new[] { "Header-1: value1\r\nHeader-2: value2\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-2: value2\x0A") }, + + // Line folding + new[] { "Header: line1\r\n line2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@" line2\x0D\x0A") }, + new[] { "Header: line1\r\n\tline2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"\x09line2\x0D\x0A") }, + new[] { "Header: line1\r\n line2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@" line2\x0D\x0A") }, + new[] { "Header: line1\r\n \tline2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@" \x09line2\x0D\x0A") }, + new[] { "Header: line1\r\n\t line2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"\x09 line2\x0D\x0A") }, + new[] { "Header: line1\r\n\t\tline2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"\x09\x09line2\x0D\x0A") }, + new[] { "Header: line1\r\n \t\t line2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@" \x09\x09 line2\x0D\x0A") }, + new[] { "Header: line1\r\n \t \t line2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@" \x09 \x09 line2\x0D\x0A") }, + new[] { "Header-1: multi\r\n line\r\nHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@" line\x0D\x0A") }, + new[] { "Header-1: value1\r\nHeader-2: multi\r\n line\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@" line\x0D\x0A") }, + new[] { "Header-1: value1\r\n Header-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@" Header-2: value2\x0D\x0A") }, + new[] { "Header-1: value1\r\n\tHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"\x09Header-2: value2\x0D\x0A") }, + + // CR in value + new[] { "Header-1: value1\r\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: value1\x0D\x0D\x0A") }, + new[] { "Header-1: val\rue1\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: val\x0Due1\x0D\x0A") }, + new[] { "Header-1: value1\rHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: value1\x0DHeader-2: value2\x0D\x0A") }, + new[] { "Header-1: value1\r\nHeader-2: value2\r\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-2: value2\x0D\x0D\x0A") }, + new[] { "Header-1: value1\r\nHeader-2: v\ralue2\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-2: v\x0Dalue2\x0D\x0A") }, + new[] { "Header-1: Value__\rVector16________Vector32\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: Value__\x0DVector16________Vector32\x0D\x0A") }, + new[] { "Header-1: Value___Vector16\r________Vector32\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: Value___Vector16\x0D________Vector32\x0D\x0A") }, + new[] { "Header-1: Value___Vector16_______\rVector32\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: Value___Vector16_______\x0DVector32\x0D\x0A") }, + new[] { "Header-1: Value___Vector16________Vector32\r\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: Value___Vector16________Vector32\x0D\x0D\x0A") }, + new[] { "Header-1: Value___Vector16________Vector32_\r\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: Value___Vector16________Vector32_\x0D\x0D\x0A") }, + new[] { "Header-1: Value___Vector16________Vector32Value___Vector16_______\rVector32\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: Value___Vector16________Vector32Value___Vector16_______\x0DVector32\x0D\x0A") }, + new[] { "Header-1: Value___Vector16________Vector32Value___Vector16________Vector32\r\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: Value___Vector16________Vector32Value___Vector16________Vector32\x0D\x0D\x0A") }, + new[] { "Header-1: Value___Vector16________Vector32Value___Vector16________Vector32_\r\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1: Value___Vector16________Vector32Value___Vector16________Vector32_\x0D\x0D\x0A") }, + + // Missing colon + new[] { "Header-1 value1\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1 value1\x0D\x0A") }, + new[] { "Header-1 value1\r\nHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-1 value1\x0D\x0A") }, + new[] { "Header-1: value1\r\nHeader-2 value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-2 value2\x0D\x0A") }, + new[] { "\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"\x0A") }, + + // Starting with whitespace + new[] { " Header: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@" Header: value\x0D\x0A") }, + new[] { "\tHeader: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"\x09Header: value\x0D\x0A") }, + new[] { " Header-1: value1\r\nHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@" Header-1: value1\x0D\x0A") }, + new[] { "\tHeader-1: value1\r\nHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"\x09Header-1: value1\x0D\x0A") }, + + // Whitespace in header name + new[] { "Header : value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header : value\x0D\x0A") }, + new[] { "Header\t: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header\x09: value\x0D\x0A") }, + new[] { "Header\r: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header\x0D: value\x0D\x0A") }, + new[] { "Header_\rVector16: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header_\x0DVector16: value\x0D\x0A") }, + new[] { "Header__Vector16\r: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header__Vector16\x0D: value\x0D\x0A") }, + new[] { "Header__Vector16_\r: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header__Vector16_\x0D: value\x0D\x0A") }, + new[] { "Header_\rVector16________Vector32: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header_\x0DVector16________Vector32: value\x0D\x0A") }, + new[] { "Header__Vector16________Vector32\r: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header__Vector16________Vector32\x0D: value\x0D\x0A") }, + new[] { "Header__Vector16________Vector32_\r: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header__Vector16________Vector32_\x0D: value\x0D\x0A") }, + new[] { "Header__Vector16________Vector32Header_\rVector16________Vector32: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header__Vector16________Vector32Header_\x0DVector16________Vector32: value\x0D\x0A") }, + new[] { "Header__Vector16________Vector32Header__Vector16________Vector32\r: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header__Vector16________Vector32Header__Vector16________Vector32\x0D: value\x0D\x0A") }, + new[] { "Header__Vector16________Vector32Header__Vector16________Vector32_\r: value\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header__Vector16________Vector32Header__Vector16________Vector32_\x0D: value\x0D\x0A") }, + new[] { "Header 1: value1\r\nHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header 1: value1\x0D\x0A") }, + new[] { "Header 1 : value1\r\nHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header 1 : value1\x0D\x0A") }, + new[] { "Header 1\t: value1\r\nHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header 1\x09: value1\x0D\x0A") }, + new[] { "Header 1\r: value1\r\nHeader-2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header 1\x0D: value1\x0D\x0A") }, + new[] { "Header-1: value1\r\nHeader 2: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header 2: value2\x0D\x0A") }, + new[] { "Header-1: value1\r\nHeader-2 : value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-2 : value2\x0D\x0A") }, + new[] { "Header-1: value1\r\nHeader-2\t: value2\r\n\r\n", CoreStrings.FormatBadRequest_InvalidRequestHeader_Detail(@"Header-2\x09: value2\x0D\x0A") }, + + // Headers not ending in CRLF line + new[] { "Header-1: value1\r\nHeader-2: value2\r\n\r\r", CoreStrings.BadRequest_InvalidRequestHeadersNoCRLF }, + new[] { "Header-1: value1\r\nHeader-2: value2\r\n\r ", CoreStrings.BadRequest_InvalidRequestHeadersNoCRLF }, + new[] { "Header-1: value1\r\nHeader-2: value2\r\n\r \n", CoreStrings.BadRequest_InvalidRequestHeadersNoCRLF }, + }; + + public static TheoryData HostHeaderData + => new TheoryData + { + { "OPTIONS *", "" }, + { "GET /pub/WWW/", "" }, + { "GET /pub/WWW/", " " }, + { "GET /pub/WWW/", "." }, + { "GET /pub/WWW/", "www.example.org" }, + { "GET http://localhost/", "localhost" }, + { "GET http://localhost:80/", "localhost:80" }, + { "GET https://localhost/", "localhost" }, + { "GET https://localhost:443/", "localhost:443" }, + { "CONNECT asp.net:80", "asp.net:80" }, + { "CONNECT asp.net:443", "asp.net:443" }, + }; + + public static TheoryData HostHeaderInvalidData + { + get + { + // see https://tools.ietf.org/html/rfc7230#section-5.4 + var invalidHostValues = new[] { + "", + " ", + "contoso.com:4000", + "contoso.com/", + "not-contoso.com", + "user@password:contoso.com", + "user@contoso.com", + "http://contoso.com/", + "http://contoso.com" + }; + + var data = new TheoryData(); + + foreach (var host in invalidHostValues) + { + // absolute form + // expected: GET http://contoso.com/ => Host: contoso.com + data.Add("GET http://contoso.com/", host); + + // authority-form + // expected: CONNECT contoso.com => Host: contoso.com + data.Add("CONNECT contoso.com", host); + } + + // port mismatch when target contains port + data.Add("GET https://contoso.com:443/", "contoso.com:5000"); + data.Add("CONNECT contoso.com:443", "contoso.com:5000"); + + return data; + } + } + } +} diff --git a/src/Servers/Kestrel/shared/test/KestrelTestLoggerProvider.cs b/src/Servers/Kestrel/shared/test/KestrelTestLoggerProvider.cs new file mode 100644 index 0000000000..48455557fb --- /dev/null +++ b/src/Servers/Kestrel/shared/test/KestrelTestLoggerProvider.cs @@ -0,0 +1,33 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Testing +{ + public class KestrelTestLoggerProvider : ILoggerProvider + { + private readonly ILogger _testLogger; + + public KestrelTestLoggerProvider(bool throwOnCriticalErrors = true) + : this(new TestApplicationErrorLogger() { ThrowOnCriticalErrors = throwOnCriticalErrors }) + { + } + + public KestrelTestLoggerProvider(ILogger testLogger) + { + _testLogger = testLogger; + } + + public ILogger CreateLogger(string categoryName) + { + return _testLogger; + } + + public void Dispose() + { + throw new NotImplementedException(); + } + } +} diff --git a/src/Servers/Kestrel/shared/test/LifetimeNotImplemented.cs b/src/Servers/Kestrel/shared/test/LifetimeNotImplemented.cs new file mode 100644 index 0000000000..df5253cc62 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/LifetimeNotImplemented.cs @@ -0,0 +1,41 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using Microsoft.AspNetCore.Hosting; + +namespace Microsoft.AspNetCore.Testing +{ + public class LifetimeNotImplemented : IApplicationLifetime + { + public CancellationToken ApplicationStarted + { + get + { + throw new NotImplementedException(); + } + } + + public CancellationToken ApplicationStopped + { + get + { + throw new NotImplementedException(); + } + } + + public CancellationToken ApplicationStopping + { + get + { + throw new NotImplementedException(); + } + } + + public void StopApplication() + { + throw new NotImplementedException(); + } + } +} diff --git a/src/Servers/Kestrel/shared/test/MockLogger.cs b/src/Servers/Kestrel/shared/test/MockLogger.cs new file mode 100644 index 0000000000..677d0687de --- /dev/null +++ b/src/Servers/Kestrel/shared/test/MockLogger.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions.Internal; + +namespace Microsoft.AspNetCore.Testing +{ + public class MockLogger : ILogger + { + private List _messages = new List(); + + public IDisposable BeginScope(TState state) + => NullScope.Instance; + + public bool IsEnabled(LogLevel logLevel) + => true; + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) + { + _messages.Add(formatter(state, exception)); + } + + public IReadOnlyList Messages => _messages; + } +} diff --git a/src/Servers/Kestrel/shared/test/MockSystemClock.cs b/src/Servers/Kestrel/shared/test/MockSystemClock.cs new file mode 100644 index 0000000000..3977653630 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/MockSystemClock.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Testing +{ + public class MockSystemClock : ISystemClock + { + private long _utcNowTicks = DateTimeOffset.UtcNow.Ticks; + + public DateTimeOffset UtcNow + { + get + { + UtcNowCalled++; + return new DateTimeOffset(Interlocked.Read(ref _utcNowTicks), TimeSpan.Zero); + } + set + { + Interlocked.Exchange(ref _utcNowTicks, value.Ticks); + } + } + + public int UtcNowCalled { get; private set; } + } +} diff --git a/src/Servers/Kestrel/shared/test/PassThroughConnectionAdapter.cs b/src/Servers/Kestrel/shared/test/PassThroughConnectionAdapter.cs new file mode 100644 index 0000000000..cfad5c7e7c --- /dev/null +++ b/src/Servers/Kestrel/shared/test/PassThroughConnectionAdapter.cs @@ -0,0 +1,171 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; + +namespace Microsoft.AspNetCore.Testing +{ + public class PassThroughConnectionAdapter : IConnectionAdapter + { + public bool IsHttps => false; + + public Task OnConnectionAsync(ConnectionAdapterContext context) + { + var adapted = new AdaptedConnection(new PassThroughStream(context.ConnectionStream)); + return Task.FromResult(adapted); + } + + private class AdaptedConnection : IAdaptedConnection + { + public AdaptedConnection(Stream stream) + { + ConnectionStream = stream; + } + + public Stream ConnectionStream { get; } + + public void Dispose() + { + } + } + + private class PassThroughStream : Stream + { + private readonly Stream _innerStream; + + public PassThroughStream(Stream innerStream) + { + _innerStream = innerStream; + } + + public override bool CanRead => _innerStream.CanRead; + + public override bool CanSeek => _innerStream.CanSeek; + + public override bool CanTimeout => _innerStream.CanTimeout; + + public override bool CanWrite => _innerStream.CanWrite; + + public override long Length => _innerStream.Length; + + public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; } + + public override int ReadTimeout { get => _innerStream.ReadTimeout; set => _innerStream.ReadTimeout = value; } + + public override int WriteTimeout { get => _innerStream.WriteTimeout; set => _innerStream.WriteTimeout = value; } + + public override int Read(byte[] buffer, int offset, int count) + { + return _innerStream.Read(buffer, offset, count); + } + + public override int ReadByte() + { + return _innerStream.ReadByte(); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _innerStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _innerStream.BeginRead(buffer, offset, count, callback, state); + } + + public override int EndRead(IAsyncResult asyncResult) + { + return _innerStream.EndRead(asyncResult); + } + + public override void Write(byte[] buffer, int offset, int count) + { + _innerStream.Write(buffer, offset, count); + } + + + public override void WriteByte(byte value) + { + _innerStream.WriteByte(value); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _innerStream.BeginWrite(buffer, offset, count, callback, state); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + _innerStream.EndWrite(asyncResult); + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + return _innerStream.CopyToAsync(destination, bufferSize, cancellationToken); + } + + public override void Flush() + { + _innerStream.Flush(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _innerStream.FlushAsync(); + + } + + public override long Seek(long offset, SeekOrigin origin) + { + return _innerStream.Seek(offset, origin); + } + + public override void SetLength(long value) + { + _innerStream.SetLength(value); + } + + public override void Close() + { + _innerStream.Close(); + } + +#if NETCOREAPP2_1 + public override int Read(Span buffer) + { + return _innerStream.Read(buffer); + } + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + return _innerStream.ReadAsync(buffer, cancellationToken); + } + + public override void Write(ReadOnlySpan buffer) + { + _innerStream.Write(buffer); + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + return _innerStream.WriteAsync(buffer, cancellationToken); + } + + public override void CopyTo(Stream destination, int bufferSize) + { + _innerStream.CopyTo(destination, bufferSize); + } +#endif + } + } +} diff --git a/src/Servers/Kestrel/shared/test/StringExtensions.cs b/src/Servers/Kestrel/shared/test/StringExtensions.cs new file mode 100644 index 0000000000..5d1756b55b --- /dev/null +++ b/src/Servers/Kestrel/shared/test/StringExtensions.cs @@ -0,0 +1,22 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Testing +{ + public static class StringExtensions + { + public static string EscapeNonPrintable(this string s) + { + var ellipsis = s.Length > 128 + ? "..." + : string.Empty; + return s.Substring(0, Math.Min(128, s.Length)) + .Replace("\r", @"\x0D") + .Replace("\n", @"\x0A") + .Replace("\0", @"\x00") + + ellipsis; + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/shared/test/TaskTimeoutExtensions.cs b/src/Servers/Kestrel/shared/test/TaskTimeoutExtensions.cs new file mode 100644 index 0000000000..8e83a7a70e --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TaskTimeoutExtensions.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Testing; + +namespace System.Threading.Tasks +{ + public static class TaskTimeoutExtensions + { + public static Task DefaultTimeout(this Task task) + { + return task.TimeoutAfter(TestConstants.DefaultTimeout); + } + + public static Task DefaultTimeout(this Task task) + { + return task.TimeoutAfter(TestConstants.DefaultTimeout); + } + } +} diff --git a/src/Servers/Kestrel/shared/test/TestApp.cs b/src/Servers/Kestrel/shared/test/TestApp.cs new file mode 100644 index 0000000000..ef4eca006b --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestApp.cs @@ -0,0 +1,48 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Testing +{ + public static class TestApp + { + public static async Task EchoApp(HttpContext httpContext) + { + var request = httpContext.Request; + var response = httpContext.Response; + var buffer = new byte[httpContext.Request.ContentLength ?? 0]; + var bytesRead = 0; + + while (bytesRead < buffer.Length) + { + var count = await request.Body.ReadAsync(buffer, bytesRead, buffer.Length - bytesRead); + bytesRead += count; + } + + if (buffer.Length > 0) + { + await response.Body.WriteAsync(buffer, 0, buffer.Length); + } + } + + public static async Task EchoAppChunked(HttpContext httpContext) + { + var request = httpContext.Request; + var response = httpContext.Response; + var data = new MemoryStream(); + await request.Body.CopyToAsync(data); + var bytes = data.ToArray(); + + response.Headers["Content-Length"] = bytes.Length.ToString(); + await response.Body.WriteAsync(bytes, 0, bytes.Length); + } + + public static Task EmptyApp(HttpContext httpContext) + { + return Task.CompletedTask; + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/shared/test/TestApplicationErrorLogger.cs b/src/Servers/Kestrel/shared/test/TestApplicationErrorLogger.cs new file mode 100644 index 0000000000..f447a32b38 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestApplicationErrorLogger.cs @@ -0,0 +1,85 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Testing +{ + public class TestApplicationErrorLogger : ILogger + { + // Application errors are logged using 13 as the eventId. + private const int ApplicationErrorEventId = 13; + + public List IgnoredExceptions { get; } = new List(); + + public bool ThrowOnCriticalErrors { get; set; } = true; + + public ConcurrentQueue Messages { get; } = new ConcurrentQueue(); + + public ConcurrentQueue Scopes { get; } = new ConcurrentQueue(); + + public int TotalErrorsLogged => Messages.Count(message => message.LogLevel == LogLevel.Error); + + public int CriticalErrorsLogged => Messages.Count(message => message.LogLevel == LogLevel.Critical); + + public int ApplicationErrorsLogged => Messages.Count(message => message.EventId.Id == ApplicationErrorEventId); + + public IDisposable BeginScope(TState state) + { + Scopes.Enqueue(state); + + return new Disposable(() => { Scopes.TryDequeue(out _); }); + } + + public bool IsEnabled(LogLevel logLevel) + { + return true; + } + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) + { +#if true + if (logLevel == LogLevel.Critical && ThrowOnCriticalErrors) +#endif + { + var log = $"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception}"; + + Console.WriteLine(log); + + if (logLevel == LogLevel.Critical && ThrowOnCriticalErrors && !IgnoredExceptions.Contains(exception.GetType())) + { + throw new Exception($"Unexpected critical error. {log}", exception); + } + } + + // Fail tests where not all the connections close during server shutdown. + if (eventId.Id == 21 && eventId.Name == nameof(KestrelTrace.NotAllConnectionsAborted)) + { + var log = $"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception?.Message}"; + throw new Exception($"Shutdown failure. {log}"); + } + + Messages.Enqueue(new LogMessage + { + LogLevel = logLevel, + EventId = eventId, + Exception = exception, + Message = formatter(state, exception) + }); + } + + public class LogMessage + { + public LogLevel LogLevel { get; set; } + public EventId EventId { get; set; } + public Exception Exception { get; set; } + public string Message { get; set; } + } + } +} diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/aspnetdevcert.pfx b/src/Servers/Kestrel/shared/test/TestCertificates/aspnetdevcert.pfx new file mode 100644 index 0000000000..e6eeeaa2e1 Binary files /dev/null and b/src/Servers/Kestrel/shared/test/TestCertificates/aspnetdevcert.pfx differ diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/eku.client.ini b/src/Servers/Kestrel/shared/test/TestCertificates/eku.client.ini new file mode 100644 index 0000000000..e2f2d8ab74 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestCertificates/eku.client.ini @@ -0,0 +1,12 @@ +# See https://www.openssl.org/docs/man1.0.2/apps/req.html for details on file format + +[ req ] +prompt = no +distinguished_name = testdn + +[ testdn ] +commonName = testcertonly + +# see https://www.openssl.org/docs/man1.0.2/apps/x509v3_config.html +[ req_extensions ] +extendedKeyUsage = clientAuth diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/eku.client.pfx b/src/Servers/Kestrel/shared/test/TestCertificates/eku.client.pfx new file mode 100644 index 0000000000..32c76a1928 Binary files /dev/null and b/src/Servers/Kestrel/shared/test/TestCertificates/eku.client.pfx differ diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/eku.code_signing.ini b/src/Servers/Kestrel/shared/test/TestCertificates/eku.code_signing.ini new file mode 100644 index 0000000000..d6a6c118f6 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestCertificates/eku.code_signing.ini @@ -0,0 +1,12 @@ +# See https://www.openssl.org/docs/man1.0.2/apps/req.html for details on file format + +[ req ] +prompt = no +distinguished_name = testdn + +[ testdn ] +commonName = testcertonly + +# see https://www.openssl.org/docs/man1.0.2/apps/x509v3_config.html +[ req_extensions ] +extendedKeyUsage = codeSigning diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/eku.code_signing.pfx b/src/Servers/Kestrel/shared/test/TestCertificates/eku.code_signing.pfx new file mode 100644 index 0000000000..050dfd05fe Binary files /dev/null and b/src/Servers/Kestrel/shared/test/TestCertificates/eku.code_signing.pfx differ diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/eku.multiple_usages.ini b/src/Servers/Kestrel/shared/test/TestCertificates/eku.multiple_usages.ini new file mode 100644 index 0000000000..128af9a6fb --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestCertificates/eku.multiple_usages.ini @@ -0,0 +1,12 @@ +# See https://www.openssl.org/docs/man1.0.2/apps/req.html for details on file format + +[ req ] +prompt = no +distinguished_name = req_distinguished_name + +[ req_distinguished_name ] +commonName = testcertonly + +# see https://www.openssl.org/docs/man1.0.2/apps/x509v3_config.html +[ req_extensions ] +extendedKeyUsage = serverAuth,clientAuth diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/eku.multiple_usages.pfx b/src/Servers/Kestrel/shared/test/TestCertificates/eku.multiple_usages.pfx new file mode 100644 index 0000000000..3bbbd9c0d4 Binary files /dev/null and b/src/Servers/Kestrel/shared/test/TestCertificates/eku.multiple_usages.pfx differ diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/eku.server.ini b/src/Servers/Kestrel/shared/test/TestCertificates/eku.server.ini new file mode 100644 index 0000000000..a3f07ef543 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestCertificates/eku.server.ini @@ -0,0 +1,12 @@ +# See https://www.openssl.org/docs/man1.0.2/apps/req.html for details on file format + +[ req ] +prompt = no +distinguished_name = testdn + +[ testdn ] +commonName = testcertonly + +# see https://www.openssl.org/docs/man1.0.2/apps/x509v3_config.html +[ req_extensions ] +extendedKeyUsage = serverAuth diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/eku.server.pfx b/src/Servers/Kestrel/shared/test/TestCertificates/eku.server.pfx new file mode 100644 index 0000000000..8ac3ad5bdb Binary files /dev/null and b/src/Servers/Kestrel/shared/test/TestCertificates/eku.server.pfx differ diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/make-test-certs.sh b/src/Servers/Kestrel/shared/test/TestCertificates/make-test-certs.sh new file mode 100644 index 0000000000..e489408b1c --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestCertificates/make-test-certs.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash + +# +# Should be obvious, but don't use the certs created here for anything real. This is just meant for our testing. +# + +set -euo pipefail + +__machine_has() { + hash "$1" > /dev/null 2>&1 + return $? +} + +# +# Main +# + +if ! __machine_has openssl; then + echo 'OpenSSL is required to create the test certificates.' 1>&2 + exit 1 +fi + +# See https://www.openssl.org/docs/man1.0.2/apps/x509.html for more details on the openssl conf file + +if [[ $# == 0 ]]; then + echo "Usage: ${BASH_SOURCE[0]} ..." + echo "" + echo "Arguments:" + echo " Multiple allowed. Path to the *.ini file that configures a cert." +fi + +# loop over all arguments +while [[ $# > 0 ]]; do + # bashism for trimming the extension + config=$1 + shift + cert_name="${config%.*}" + key="$cert_name.pem" + cert="$cert_name.crt" + pfx="$cert_name.pfx" + + echo "Creating cert $cert_name" + + # see https://www.openssl.org/docs/man1.0.2/apps/req.html + openssl req -x509 \ + -days 1 \ + -config $config \ + -nodes \ + -newkey rsa:2048 \ + -keyout $key \ + -extensions req_extensions \ + -out $cert + + # See https://www.openssl.org/docs/man1.0.2/apps/pkcs12.html + openssl pkcs12 -export \ + -in $cert \ + -inkey $key \ + -out $pfx \ + -password pass:testPassword # so secure ;) + + rm $key $cert +done diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/no_extensions.ini b/src/Servers/Kestrel/shared/test/TestCertificates/no_extensions.ini new file mode 100644 index 0000000000..df234f06a6 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestCertificates/no_extensions.ini @@ -0,0 +1,13 @@ +# See https://www.openssl.org/docs/man1.0.2/apps/req.html for details on file format + +[ req ] +prompt = no +distinguished_name = testdn + +[ testdn ] +commonName = testcertonly + +# see https://www.openssl.org/docs/man1.0.2/apps/x509v3_config.html +[ req_extensions ] +# keyUsages = +# extendedKeyUsage = diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/no_extensions.pfx b/src/Servers/Kestrel/shared/test/TestCertificates/no_extensions.pfx new file mode 100644 index 0000000000..b4be4b5eda Binary files /dev/null and b/src/Servers/Kestrel/shared/test/TestCertificates/no_extensions.pfx differ diff --git a/src/Servers/Kestrel/shared/test/TestCertificates/testCert.pfx b/src/Servers/Kestrel/shared/test/TestCertificates/testCert.pfx new file mode 100644 index 0000000000..7118908c2d Binary files /dev/null and b/src/Servers/Kestrel/shared/test/TestCertificates/testCert.pfx differ diff --git a/src/Servers/Kestrel/shared/test/TestConnection.cs b/src/Servers/Kestrel/shared/test/TestConnection.cs new file mode 100644 index 0000000000..6469dd2d19 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestConnection.cs @@ -0,0 +1,277 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.IO; +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNetCore.Testing +{ + /// + /// Summary description for TestConnection + /// + public class TestConnection : IDisposable + { + private static readonly TimeSpan Timeout = TimeSpan.FromMinutes(1); + + private readonly bool _ownsSocket; + private readonly Socket _socket; + private readonly NetworkStream _stream; + private readonly StreamReader _reader; + + public TestConnection(int port) + : this(port, AddressFamily.InterNetwork) + { + } + + public TestConnection(int port, AddressFamily addressFamily) + : this(CreateConnectedLoopbackSocket(port, addressFamily), ownsSocket: true) + { + } + + public TestConnection(Socket socket) + : this(socket, ownsSocket: false) + { + } + + private TestConnection(Socket socket, bool ownsSocket) + { + _ownsSocket = ownsSocket; + _socket = socket; + _stream = new NetworkStream(_socket, ownsSocket: false); + _reader = new StreamReader(_stream, Encoding.ASCII); + } + + public Socket Socket => _socket; + + public Stream Stream => _stream; + + public StreamReader Reader => _reader; + + public void Dispose() + { + _stream.Dispose(); + + if (_ownsSocket) + { + _socket.Dispose(); + } + } + + public Task SendEmptyGet() + { + return Send("GET / HTTP/1.1", + "Host:", + "", + ""); + } + + public Task SendEmptyGetWithUpgradeAndKeepAlive() + => SendEmptyGetWithConnection("Upgrade, keep-alive"); + + public Task SendEmptyGetWithUpgrade() + => SendEmptyGetWithConnection("Upgrade"); + + public Task SendEmptyGetAsKeepAlive() + => SendEmptyGetWithConnection("keep-alive"); + + private Task SendEmptyGetWithConnection(string connection) + { + return Send("GET / HTTP/1.1", + "Host:", + "Connection: " + connection, + "", + ""); + } + + public async Task SendAll(params string[] lines) + { + var text = string.Join("\r\n", lines); + var writer = new StreamWriter(_stream, Encoding.GetEncoding("iso-8859-1")); + await writer.WriteAsync(text).ConfigureAwait(false); + await writer.FlushAsync().ConfigureAwait(false); + await _stream.FlushAsync().ConfigureAwait(false); + } + + public async Task Send(params string[] lines) + { + var text = string.Join("\r\n", lines); + var writer = new StreamWriter(_stream, Encoding.GetEncoding("iso-8859-1")); + for (var index = 0; index < text.Length; index++) + { + var ch = text[index]; + writer.Write(ch); + await writer.FlushAsync().ConfigureAwait(false); + // Re-add delay to help find socket input consumption bugs more consistently + //await Task.Delay(TimeSpan.FromMilliseconds(5)); + } + await writer.FlushAsync().ConfigureAwait(false); + await _stream.FlushAsync().ConfigureAwait(false); + } + + public async Task Receive(params string[] lines) + { + var expected = string.Join("\r\n", lines); + var actual = new char[expected.Length]; + var offset = 0; + + try + { + while (offset < expected.Length) + { + var data = new byte[expected.Length]; + var task = _reader.ReadAsync(actual, offset, actual.Length - offset); + if (!Debugger.IsAttached) + { + task = task.TimeoutAfter(Timeout); + } + var count = await task.ConfigureAwait(false); + if (count == 0) + { + break; + } + offset += count; + } + } + catch (TimeoutException ex) when (offset != 0) + { + throw new TimeoutException($"Did not receive a complete response within {Timeout}.{Environment.NewLine}{Environment.NewLine}" + + $"Expected:{Environment.NewLine}{expected}{Environment.NewLine}{Environment.NewLine}" + + $"Actual:{Environment.NewLine}{new string(actual, 0, offset)}{Environment.NewLine}", + ex); + } + + Assert.Equal(expected, new string(actual, 0, offset)); + } + + public Task ReceiveEnd(params string[] lines) + => ReceiveEnd(false, lines); + + public async Task ReceiveEnd(bool ignoreResponse, params string[] lines) + { + await Receive(lines).ConfigureAwait(false); + _socket.Shutdown(SocketShutdown.Send); + var ch = new char[128]; + var count = await _reader.ReadAsync(ch, 0, 128).TimeoutAfter(Timeout).ConfigureAwait(false); + if (!ignoreResponse) + { + var text = new string(ch, 0, count); + Assert.Equal("", text); + } + } + + public async Task ReceiveForcedEnd(params string[] lines) + { + await Receive(lines).ConfigureAwait(false); + + try + { + var ch = new char[128]; + var count = await _reader.ReadAsync(ch, 0, 128).TimeoutAfter(Timeout).ConfigureAwait(false); + var text = new string(ch, 0, count); + Assert.Equal("", text); + } + catch (IOException) + { + // The server is forcefully closing the connection so an IOException: + // "Unable to read data from the transport connection: An existing connection was forcibly closed by the remote host." + // isn't guaranteed but not unexpected. + } + } + + public async Task ReceiveStartsWith(string prefix, int maxLineLength = 1024) + { + var actual = new char[maxLineLength]; + var offset = 0; + + while (offset < maxLineLength) + { + // Read one char at a time so we don't read past the end of the line. + var task = _reader.ReadAsync(actual, offset, 1); + if (!Debugger.IsAttached) + { + Assert.True(task.Wait(4000), "timeout"); + } + var count = await task.ConfigureAwait(false); + if (count == 0) + { + break; + } + + Assert.True(count == 1); + offset++; + + if (actual[offset - 1] == '\n') + { + break; + } + } + + var actualLine = new string(actual, 0, offset); + Assert.StartsWith(prefix, actualLine); + } + + public void Shutdown(SocketShutdown how) + { + _socket.Shutdown(how); + } + + public void Reset() + { + _socket.LingerState = new LingerOption(true, 0); + _socket.Dispose(); + } + + public Task WaitForConnectionClose() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var eventArgs = new SocketAsyncEventArgs(); + eventArgs.SetBuffer(new byte[128], 0, 128); + eventArgs.Completed += ReceiveAsyncCompleted; + eventArgs.UserToken = tcs; + + if (!_socket.ReceiveAsync(eventArgs)) + { + ReceiveAsyncCompleted(this, eventArgs); + } + + return tcs.Task; + } + + private void ReceiveAsyncCompleted(object sender, SocketAsyncEventArgs e) + { + var tcs = (TaskCompletionSource)e.UserToken; + if (e.BytesTransferred == 0) + { + tcs.SetResult(null); + } + else + { + tcs.SetException(new IOException( + $"Expected connection close, received data instead: \"{_reader.CurrentEncoding.GetString(e.Buffer, 0, e.BytesTransferred)}\"")); + } + } + + public static Socket CreateConnectedLoopbackSocket(int port) => CreateConnectedLoopbackSocket(port, AddressFamily.InterNetwork); + + public static Socket CreateConnectedLoopbackSocket(int port, AddressFamily addressFamily) + { + if (addressFamily != AddressFamily.InterNetwork && addressFamily != AddressFamily.InterNetworkV6) + { + throw new ArgumentException($"TestConnection does not support address family of type {addressFamily}", nameof(addressFamily)); + } + + var socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp); + var address = addressFamily == AddressFamily.InterNetworkV6 + ? IPAddress.IPv6Loopback + : IPAddress.Loopback; + socket.Connect(new IPEndPoint(address, port)); + return socket; + } + } +} diff --git a/src/Servers/Kestrel/shared/test/TestConstants.cs b/src/Servers/Kestrel/shared/test/TestConstants.cs new file mode 100644 index 0000000000..9615b4c5ef --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestConstants.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Testing +{ + public class TestConstants + { + public const int EOF = -4095; + public static TimeSpan DefaultTimeout = TimeSpan.FromSeconds(30); + } +} diff --git a/src/Servers/Kestrel/shared/test/TestHttp1Connection.cs b/src/Servers/Kestrel/shared/test/TestHttp1Connection.cs new file mode 100644 index 0000000000..f53a2a52db --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestHttp1Connection.cs @@ -0,0 +1,40 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Testing +{ + public class TestHttp1Connection : Http1Connection + { + public TestHttp1Connection(Http1ConnectionContext context) + : base(context) + { + } + + public HttpVersion HttpVersionEnum + { + get => _httpVersion; + set => _httpVersion = value; + } + + public bool KeepAlive + { + get => _keepAlive; + set => _keepAlive = value; + } + + public MessageBody NextMessageBody { private get; set; } + + public Task ProduceEndAsync() + { + return ProduceEnd(); + } + + protected override MessageBody CreateMessageBody() + { + return NextMessageBody ?? base.CreateMessageBody(); + } + } +} diff --git a/src/Servers/Kestrel/shared/test/TestKestrelTrace.cs b/src/Servers/Kestrel/shared/test/TestKestrelTrace.cs new file mode 100644 index 0000000000..8e0a469dc7 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestKestrelTrace.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; + +namespace Microsoft.AspNetCore.Testing +{ + public class TestKestrelTrace : KestrelTrace + { + public TestKestrelTrace() : this(new TestApplicationErrorLogger()) + { + } + + public TestKestrelTrace(TestApplicationErrorLogger testLogger) : base(testLogger) + { + Logger = testLogger; + } + + public TestApplicationErrorLogger Logger { get; private set; } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/shared/test/TestResources.cs b/src/Servers/Kestrel/shared/test/TestResources.cs new file mode 100644 index 0000000000..3218a1eaca --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestResources.cs @@ -0,0 +1,26 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; +using System.Security.Cryptography.X509Certificates; + +namespace Microsoft.AspNetCore.Testing +{ + public static class TestResources + { + private static readonly string _baseDir = Directory.GetCurrentDirectory(); + + public static string TestCertificatePath { get; } = Path.Combine(_baseDir, "testCert.pfx"); + public static string GetCertPath(string name) => Path.Combine(_baseDir, name); + + public static X509Certificate2 GetTestCertificate() + { + return new X509Certificate2(TestCertificatePath, "testPassword"); + } + + public static X509Certificate2 GetTestCertificate(string certName) + { + return new X509Certificate2(GetCertPath(certName), "testPassword"); + } + } +} diff --git a/src/Servers/Kestrel/shared/test/TestServiceContext.cs b/src/Servers/Kestrel/shared/test/TestServiceContext.cs new file mode 100644 index 0000000000..d662ee9686 --- /dev/null +++ b/src/Servers/Kestrel/shared/test/TestServiceContext.cs @@ -0,0 +1,53 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO.Pipelines; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Testing +{ + public class TestServiceContext : ServiceContext + { + public TestServiceContext() + { + var logger = new TestApplicationErrorLogger(); + var kestrelTrace = new TestKestrelTrace(logger); + var loggerFactory = new LoggerFactory(new[] { new KestrelTestLoggerProvider(logger) }); + + Initialize(loggerFactory, kestrelTrace); + } + + public TestServiceContext(ILoggerFactory loggerFactory) + : this(loggerFactory, new KestrelTrace(loggerFactory.CreateLogger("Microsoft.AspNetCore.Server.Kestrel"))) + { + } + + public TestServiceContext(ILoggerFactory loggerFactory, IKestrelTrace kestrelTrace) + { + Initialize(loggerFactory, kestrelTrace); + } + + private void Initialize(ILoggerFactory loggerFactory, IKestrelTrace kestrelTrace) + { + LoggerFactory = loggerFactory; + Log = kestrelTrace; + Scheduler = PipeScheduler.ThreadPool; + SystemClock = new MockSystemClock(); + DateHeaderValueManager = new DateHeaderValueManager(SystemClock); + ConnectionManager = new HttpConnectionManager(Log, ResourceCounter.Unlimited); + HttpParser = new HttpParser(Log.IsEnabled(LogLevel.Information)); + ServerOptions = new KestrelServerOptions + { + AddServerHeader = false + }; + } + + public ILoggerFactory LoggerFactory { get; set; } + + public string DateHeaderValue => DateHeaderValueManager.GetDateHeaderValues().String; + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/AddressRegistrationTests.cs b/src/Servers/Kestrel/test/FunctionalTests/AddressRegistrationTests.cs new file mode 100644 index 0000000000..3843c6c532 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/AddressRegistrationTests.cs @@ -0,0 +1,1077 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.NetworkInformation; +using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Hosting.Server.Features; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Extensions; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Microsoft.AspNetCore.Testing.xunit; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Xunit; +using Xunit.Sdk; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class AddressRegistrationTests : TestApplicationErrorLoggerLoggedTest + { + private const int MaxRetries = 10; + + [ConditionalFact] + [HostNameIsReachable] + public async Task RegisterAddresses_HostName_Success() + { + var hostName = Dns.GetHostName(); + await RegisterAddresses_Success($"http://{hostName}:0", $"http://{hostName}"); + } + + [Theory] + [MemberData(nameof(AddressRegistrationDataIPv4))] + public async Task RegisterAddresses_IPv4_Success(string addressInput, string testUrl) + { + await RegisterAddresses_Success(addressInput, testUrl); + } + + [ConditionalTheory] + [MemberData(nameof(AddressRegistrationDataIPv4Port5000Default))] + [PortSupportedCondition(5000)] + public async Task RegisterAddresses_IPv4Port5000Default_Success(string addressInput, string testUrl) + { + await RegisterAddresses_Success(addressInput, testUrl, 5000); + } + + [ConditionalTheory] + [MemberData(nameof(AddressRegistrationDataIPv4Port80))] + [PortSupportedCondition(80)] + public async Task RegisterAddresses_IPv4Port80_Success(string addressInput, string testUrl) + { + await RegisterAddresses_Success(addressInput, testUrl, 80); + } + + [Fact] + public async Task RegisterAddresses_IPv4StaticPort_Success() + { + await RegisterAddresses_StaticPort_Success("http://127.0.0.1", "http://127.0.0.1"); + } + + [Fact] + public async Task RegisterAddresses_IPv4LocalhostStaticPort_Success() + { + await RegisterAddresses_StaticPort_Success("http://localhost", "http://127.0.0.1"); + } + + [Fact] + public async Task RegisterIPEndPoint_IPv4StaticPort_Success() + { + await RegisterIPEndPoint_StaticPort_Success(IPAddress.Loopback, $"http://127.0.0.1"); + } + + [ConditionalFact] + [IPv6SupportedCondition] + public async Task RegisterIPEndPoint_IPv6StaticPort_Success() + { + await RegisterIPEndPoint_StaticPort_Success(IPAddress.IPv6Loopback, $"http://[::1]"); + } + + [ConditionalTheory] + [MemberData(nameof(IPEndPointRegistrationDataDynamicPort))] + [IPv6SupportedCondition] + public async Task RegisterIPEndPoint_DynamicPort_Success(IPEndPoint endPoint, string testUrl) + { + await RegisterIPEndPoint_Success(endPoint, testUrl); + } + + [ConditionalTheory] + [MemberData(nameof(IPEndPointRegistrationDataPort443))] + [IPv6SupportedCondition] + [PortSupportedCondition(443)] + public async Task RegisterIPEndPoint_Port443_Success(IPEndPoint endpoint, string testUrl) + { + await RegisterIPEndPoint_Success(endpoint, testUrl, 443); + } + + [ConditionalTheory(Skip="https://github.com/aspnet/KestrelHttpServer/issues/2434")] + [MemberData(nameof(AddressRegistrationDataIPv6))] + [IPv6SupportedCondition] + public async Task RegisterAddresses_IPv6_Success(string addressInput, string[] testUrls) + { + await RegisterAddresses_Success(addressInput, testUrls); + } + + [ConditionalTheory] + [MemberData(nameof(AddressRegistrationDataIPv6Port5000Default))] + [IPv6SupportedCondition] + [PortSupportedCondition(5000)] + public async Task RegisterAddresses_IPv6Port5000Default_Success(string addressInput, string[] testUrls) + { + await RegisterAddresses_Success(addressInput, testUrls); + } + + [ConditionalTheory] + [MemberData(nameof(AddressRegistrationDataIPv6Port80))] + [IPv6SupportedCondition] + [PortSupportedCondition(80)] + public async Task RegisterAddresses_IPv6Port80_Success(string addressInput, string[] testUrls) + { + await RegisterAddresses_Success(addressInput, testUrls); + } + + [ConditionalTheory] + [MemberData(nameof(AddressRegistrationDataIPv6ScopeId))] + [IPv6SupportedCondition] + [IPv6ScopeIdPresentCondition] + public async Task RegisterAddresses_IPv6ScopeId_Success(string addressInput, string testUrl) + { + await RegisterAddresses_Success(addressInput, testUrl); + } + + [ConditionalFact] + [IPv6SupportedCondition] + public async Task RegisterAddresses_IPv6StaticPort_Success() + { + await RegisterAddresses_StaticPort_Success("http://[::1]", "http://[::1]"); + } + + [ConditionalFact] + [IPv6SupportedCondition] + public async Task RegisterAddresses_IPv6LocalhostStaticPort_Success() + { + await RegisterAddresses_StaticPort_Success("http://localhost", new[] { "http://localhost", "http://127.0.0.1", "http://[::1]" }); + } + + private async Task RegisterAddresses_Success(string addressInput, string[] testUrls, int testPort = 0) + { + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .ConfigureServices(AddTestLogging) + .UseUrls(addressInput) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + foreach (var testUrl in testUrls.Select(testUrl => $"{testUrl}:{(testPort == 0 ? host.GetPort() : testPort)}")) + { + var response = await HttpClientSlim.GetStringAsync(testUrl, validateCertificate: false); + + // Compare the response with Uri.ToString(), rather than testUrl directly. + // Required to handle IPv6 addresses with zone index, like "fe80::3%1" + Assert.Equal(new Uri(testUrl).ToString(), response); + } + } + } + + private Task RegisterAddresses_Success(string addressInput, string testUrl, int testPort = 0) + => RegisterAddresses_Success(addressInput, new[] { testUrl }, testPort); + + private Task RegisterAddresses_StaticPort_Success(string addressInput, string[] testUrls) => + RunTestWithStaticPort(port => RegisterAddresses_Success($"{addressInput}:{port}", testUrls, port)); + + [Fact] + public async Task RegisterHttpAddress_UpradedToHttpsByConfigureEndpointDefaults() + { + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(serverOptions => + { + serverOptions.ConfigureEndpointDefaults(listenOptions => + { + listenOptions.UseHttps(TestResources.GetTestCertificate()); + }); + }) + .ConfigureServices(AddTestLogging) + .UseUrls("http://127.0.0.1:0") + .Configure(app => + { + var serverAddresses = app.ServerFeatures.Get(); + app.Run(context => + { + Assert.Single(serverAddresses.Addresses); + return context.Response.WriteAsync(serverAddresses.Addresses.First()); + }); + }); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + var expectedUrl = $"https://127.0.0.1:{host.GetPort()}"; + var response = await HttpClientSlim.GetStringAsync(expectedUrl, validateCertificate: false); + + Assert.Equal(expectedUrl, response); + } + } + + private async Task RunTestWithStaticPort(Func test) + { + var retryCount = 0; + var errors = new List(); + + while (retryCount < MaxRetries) + { + try + { + var port = GetNextPort(); + await test(port); + return; + } + catch (XunitException) + { + throw; + } + catch (Exception ex) + { + errors.Add(ex); + } + + retryCount++; + } + + if (errors.Any()) + { + throw new AggregateException(errors); + } + } + + private Task RegisterAddresses_StaticPort_Success(string addressInput, string testUrl) + => RegisterAddresses_StaticPort_Success(addressInput, new[] { testUrl }); + + private async Task RegisterIPEndPoint_Success(IPEndPoint endPoint, string testUrl, int testPort = 0) + { + var hostBuilder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel(options => + { + options.Listen(endPoint, listenOptions => + { + if (testUrl.StartsWith("https")) + { + listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword"); + } + }); + }) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + var testUrlWithPort = $"{testUrl}:{(testPort == 0 ? host.GetPort() : testPort)}"; + + var options = ((IOptions)host.Services.GetService(typeof(IOptions))).Value; + Assert.Single(options.ListenOptions); + + var response = await HttpClientSlim.GetStringAsync(testUrlWithPort, validateCertificate: false); + + // Compare the response with Uri.ToString(), rather than testUrl directly. + // Required to handle IPv6 addresses with zone index, like "fe80::3%1" + Assert.Equal(new Uri(testUrlWithPort).ToString(), response); + } + } + + private Task RegisterIPEndPoint_StaticPort_Success(IPAddress address, string testUrl) + => RunTestWithStaticPort(port => RegisterIPEndPoint_Success(new IPEndPoint(address, port), testUrl, port)); + + [ConditionalFact] + public async Task ListenAnyIP_IPv4_Success() + { + await ListenAnyIP_Success(new[] { "http://localhost", "http://127.0.0.1" }); + } + + [ConditionalFact] + [IPv6SupportedCondition] + public async Task ListenAnyIP_IPv6_Success() + { + await ListenAnyIP_Success(new[] { "http://[::1]", "http://localhost", "http://127.0.0.1" }); + } + + [ConditionalFact] + [HostNameIsReachable] + public async Task ListenAnyIP_HostName_Success() + { + var hostName = Dns.GetHostName(); + await ListenAnyIP_Success(new[] { $"http://{hostName}" }); + } + + private async Task ListenAnyIP_Success(string[] testUrls, int testPort = 0) + { + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.ListenAnyIP(testPort); + }) + .ConfigureServices(AddTestLogging) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + foreach (var testUrl in testUrls.Select(testUrl => $"{testUrl}:{(testPort == 0 ? host.GetPort() : testPort)}")) + { + var response = await HttpClientSlim.GetStringAsync(testUrl, validateCertificate: false); + + // Compare the response with Uri.ToString(), rather than testUrl directly. + // Required to handle IPv6 addresses with zone index, like "fe80::3%1" + Assert.Equal(new Uri(testUrl).ToString(), response); + } + } + } + + [ConditionalFact] + public async Task ListenLocalhost_IPv4LocalhostStaticPort_Success() + { + await ListenLocalhost_StaticPort_Success(new[] { "http://localhost", "http://127.0.0.1" }); + } + + [ConditionalFact] + [IPv6SupportedCondition] + public async Task ListenLocalhost_IPv6LocalhostStaticPort_Success() + { + await ListenLocalhost_StaticPort_Success(new[] { "http://localhost", "http://127.0.0.1", "http://[::1]" }); + } + + private Task ListenLocalhost_StaticPort_Success(string[] testUrls) => + RunTestWithStaticPort(port => ListenLocalhost_Success(testUrls, port)); + + private async Task ListenLocalhost_Success(string[] testUrls, int testPort = 0) + { + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.ListenLocalhost(testPort); + }) + .ConfigureServices(AddTestLogging) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + foreach (var testUrl in testUrls.Select(testUrl => $"{testUrl}:{(testPort == 0 ? host.GetPort() : testPort)}")) + { + var response = await HttpClientSlim.GetStringAsync(testUrl, validateCertificate: false); + + // Compare the response with Uri.ToString(), rather than testUrl directly. + // Required to handle IPv6 addresses with zone index, like "fe80::3%1" + Assert.Equal(new Uri(testUrl).ToString(), response); + } + } + } + + [ConditionalFact] + [PortSupportedCondition(5000)] + public Task DefaultsServerAddress_BindsToIPv4() + { + return RegisterDefaultServerAddresses_Success(new[] { "http://127.0.0.1:5000" }); + } + + [ConditionalFact] + [IPv6SupportedCondition] + [PortSupportedCondition(5000)] + public Task DefaultsServerAddress_BindsToIPv6() + { + return RegisterDefaultServerAddresses_Success(new[] { "http://127.0.0.1:5000", "http://[::1]:5000" }); + } + + [ConditionalFact] + [PortSupportedCondition(5000)] + [PortSupportedCondition(5001)] + public Task DefaultsServerAddress_BindsToIPv4WithHttps() + { + return RegisterDefaultServerAddresses_Success( + new[] { "http://127.0.0.1:5000", "https://127.0.0.1:5001" }, mockHttps: true); + } + + [ConditionalFact] + [IPv6SupportedCondition] + [PortSupportedCondition(5000)] + [PortSupportedCondition(5001)] + public Task DefaultsServerAddress_BindsToIPv6WithHttps() + { + return RegisterDefaultServerAddresses_Success(new[] { + "http://127.0.0.1:5000", "http://[::1]:5000", + "https://127.0.0.1:5001", "https://[::1]:5001"}, + mockHttps: true); + } + + private async Task RegisterDefaultServerAddresses_Success(IEnumerable addresses, bool mockHttps = false) + { + var hostBuilder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel(options => + { + if (mockHttps) + { + options.DefaultCertificate = new X509Certificate2(TestResources.TestCertificatePath, "testPassword"); + } + }) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + Assert.Equal(5000, host.GetPort()); + + if (mockHttps) + { + Assert.Contains(5001, host.GetPorts()); + } + + Assert.Single(TestApplicationErrorLogger.Messages, log => log.LogLevel == LogLevel.Debug && + (string.Equals(CoreStrings.FormatBindingToDefaultAddresses(Constants.DefaultServerAddress, Constants.DefaultServerHttpsAddress), log.Message, StringComparison.Ordinal) + || string.Equals(CoreStrings.FormatBindingToDefaultAddress(Constants.DefaultServerAddress), log.Message, StringComparison.Ordinal))); + + foreach (var address in addresses) + { + Assert.Equal(new Uri(address).ToString(), await HttpClientSlim.GetStringAsync(address, validateCertificate: false)); + } + } + } + + [Fact] + public void ThrowsWhenBindingToIPv4AddressInUse() + { + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + socket.Listen(0); + var port = ((IPEndPoint)socket.LocalEndPoint).Port; + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls($"http://127.0.0.1:{port}") + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + var exception = Assert.Throws(() => host.Start()); + Assert.Equal(CoreStrings.FormatEndpointAlreadyInUse($"http://127.0.0.1:{port}"), exception.Message); + } + } + } + + [ConditionalFact] + [IPv6SupportedCondition] + public void ThrowsWhenBindingToIPv6AddressInUse() + { + TestApplicationErrorLogger.IgnoredExceptions.Add(typeof(IOException)); + + using (var socket = new Socket(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Bind(new IPEndPoint(IPAddress.IPv6Loopback, 0)); + socket.Listen(0); + var port = ((IPEndPoint)socket.LocalEndPoint).Port; + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel() + .UseUrls($"http://[::1]:{port}") + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + var exception = Assert.Throws(() => host.Start()); + Assert.Equal(CoreStrings.FormatEndpointAlreadyInUse($"http://[::1]:{port}"), exception.Message); + } + } + } + + [Fact] + public async Task OverrideDirectConfigurationWithIServerAddressesFeature_Succeeds() + { + var useUrlsAddress = $"http://127.0.0.1:0"; + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword"); + }); + }) + .UseUrls(useUrlsAddress) + .PreferHostingUrls(true) + .ConfigureServices(AddTestLogging) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + var port = host.GetPort(); + + // If this isn't working properly, we'll get the HTTPS endpoint defined in UseKestrel + // instead of the HTTP endpoint defined in UseUrls. + var serverAddresses = host.ServerFeatures.Get().Addresses; + Assert.Equal(1, serverAddresses.Count); + var useUrlsAddressWithPort = $"http://127.0.0.1:{port}"; + Assert.Equal(serverAddresses.First(), useUrlsAddressWithPort); + + Assert.Single(TestApplicationErrorLogger.Messages, log => log.LogLevel == LogLevel.Information && + string.Equals(CoreStrings.FormatOverridingWithPreferHostingUrls(nameof(IServerAddressesFeature.PreferHostingUrls), useUrlsAddress), + log.Message, StringComparison.Ordinal)); + + Assert.Equal(new Uri(useUrlsAddressWithPort).ToString(), await HttpClientSlim.GetStringAsync(useUrlsAddressWithPort)); + } + } + + [Fact] + public async Task DoesNotOverrideDirectConfigurationWithIServerAddressesFeature_IfPreferHostingUrlsFalse() + { + var useUrlsAddress = $"http://127.0.0.1:0"; + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword"); + }); + }) + .UseUrls($"http://127.0.0.1:0") + .PreferHostingUrls(false) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + var port = host.GetPort(); + + // If this isn't working properly, we'll get the HTTP endpoint defined in UseUrls + // instead of the HTTPS endpoint defined in UseKestrel. + var serverAddresses = host.ServerFeatures.Get().Addresses; + Assert.Equal(1, serverAddresses.Count); + var endPointAddress = $"https://127.0.0.1:{port}"; + Assert.Equal(serverAddresses.First(), endPointAddress); + + Assert.Single(TestApplicationErrorLogger.Messages, log => log.LogLevel == LogLevel.Warning && + string.Equals(CoreStrings.FormatOverridingWithKestrelOptions(useUrlsAddress, "UseKestrel()"), + log.Message, StringComparison.Ordinal)); + + Assert.Equal(new Uri(endPointAddress).ToString(), await HttpClientSlim.GetStringAsync(endPointAddress, validateCertificate: false)); + } + } + + [Fact] + public async Task DoesNotOverrideDirectConfigurationWithIServerAddressesFeature_IfAddressesEmpty() + { + var hostBuilder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword"); + }); + }) + .PreferHostingUrls(true) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + var port = host.GetPort(); + + // If this isn't working properly, we'll not get the HTTPS endpoint defined in UseKestrel. + var serverAddresses = host.ServerFeatures.Get().Addresses; + Assert.Equal(1, serverAddresses.Count); + var endPointAddress = $"https://127.0.0.1:{port}"; + Assert.Equal(serverAddresses.First(), endPointAddress); + + Assert.Equal(new Uri(endPointAddress).ToString(), await HttpClientSlim.GetStringAsync(endPointAddress, validateCertificate: false)); + } + } + + [Fact] + public void ThrowsWhenBindingLocalhostToIPv4AddressInUse() + { + ThrowsWhenBindingLocalhostToAddressInUse(AddressFamily.InterNetwork); + } + + [ConditionalFact] + [IPv6SupportedCondition] + public void ThrowsWhenBindingLocalhostToIPv6AddressInUse() + { + ThrowsWhenBindingLocalhostToAddressInUse(AddressFamily.InterNetworkV6); + } + + [Fact] + public void ThrowsWhenBindingLocalhostToDynamicPort() + { + TestApplicationErrorLogger.IgnoredExceptions.Add(typeof(InvalidOperationException)); + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel() + .UseUrls("http://localhost:0") + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + Assert.Throws(() => host.Start()); + } + } + + [Theory] + [InlineData("ftp://localhost")] + [InlineData("ssh://localhost")] + public void ThrowsForUnsupportedAddressFromHosting(string addr) + { + TestApplicationErrorLogger.IgnoredExceptions.Add(typeof(InvalidOperationException)); + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel() + .UseUrls(addr) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + Assert.Throws(() => host.Start()); + } + } + + [Fact] + public async Task CanRebindToEndPoint() + { + var port = GetNextPort(); + var endPointAddress = $"http://127.0.0.1:{port}/"; + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel(options => + { + options.Listen(IPAddress.Loopback, port); + }) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + Assert.Equal(endPointAddress, await HttpClientSlim.GetStringAsync(endPointAddress)); + } + + hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.Listen(IPAddress.Loopback, port); + }) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + Assert.Equal(endPointAddress, await HttpClientSlim.GetStringAsync(endPointAddress)); + } + } + + [ConditionalFact] + [IPv6SupportedCondition] + public async Task CanRebindToMultipleEndPoints() + { + var port = GetNextPort(); + var ipv4endPointAddress = $"http://127.0.0.1:{port}/"; + var ipv6endPointAddress = $"http://[::1]:{port}/"; + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel(options => + { + options.Listen(IPAddress.Loopback, port); + options.Listen(IPAddress.IPv6Loopback, port); + }) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + Assert.Equal(ipv4endPointAddress, await HttpClientSlim.GetStringAsync(ipv4endPointAddress)); + Assert.Equal(ipv6endPointAddress, await HttpClientSlim.GetStringAsync(ipv6endPointAddress)); + } + + hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.Listen(IPAddress.Loopback, port); + options.Listen(IPAddress.IPv6Loopback, port); + }) + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + Assert.Equal(ipv4endPointAddress, await HttpClientSlim.GetStringAsync(ipv4endPointAddress)); + Assert.Equal(ipv6endPointAddress, await HttpClientSlim.GetStringAsync(ipv6endPointAddress)); + } + } + + private void ThrowsWhenBindingLocalhostToAddressInUse(AddressFamily addressFamily) + { + TestApplicationErrorLogger.IgnoredExceptions.Add(typeof(IOException)); + + var addressInUseCount = 0; + var wrongMessageCount = 0; + + var address = addressFamily == AddressFamily.InterNetwork ? IPAddress.Loopback : IPAddress.IPv6Loopback; + var otherAddressFamily = addressFamily == AddressFamily.InterNetwork ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork; + + while (addressInUseCount < 10 && wrongMessageCount < 10) + { + int port; + + using (var socket = new Socket(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp)) + { + // Bind first to IPv6Any to ensure both the IPv4 and IPv6 ports are avaiable. + socket.Bind(new IPEndPoint(IPAddress.IPv6Any, 0)); + socket.Listen(0); + port = ((IPEndPoint)socket.LocalEndPoint).Port; + } + + using (var socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp)) + { + try + { + socket.Bind(new IPEndPoint(address, port)); + socket.Listen(0); + } + catch (SocketException) + { + addressInUseCount++; + continue; + } + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel() + .UseUrls($"http://localhost:{port}") + .Configure(ConfigureEchoAddress); + + using (var host = hostBuilder.Build()) + { + var exception = Assert.Throws(() => host.Start()); + + var thisAddressString = $"http://{(addressFamily == AddressFamily.InterNetwork ? "127.0.0.1" : "[::1]")}:{port}"; + var otherAddressString = $"http://{(addressFamily == AddressFamily.InterNetworkV6? "127.0.0.1" : "[::1]")}:{port}"; + + if (exception.Message == CoreStrings.FormatEndpointAlreadyInUse(otherAddressString)) + { + // Don't fail immediately, because it's possible that something else really did bind to the + // same port for the other address family between the IPv6Any bind above and now. + wrongMessageCount++; + continue; + } + + Assert.Equal(CoreStrings.FormatEndpointAlreadyInUse(thisAddressString), exception.Message); + break; + } + } + } + + if (addressInUseCount >= 10) + { + Assert.True(false, $"The corresponding {otherAddressFamily} address was already in use 10 times."); + } + + if (wrongMessageCount >= 10) + { + Assert.True(false, $"An error for a conflict with {otherAddressFamily} was thrown 10 times."); + } + } + + public static TheoryData AddressRegistrationDataIPv4 + { + get + { + var dataset = new TheoryData(); + + // Loopback + dataset.Add("http://127.0.0.1:0", "http://127.0.0.1"); + + // Any + dataset.Add("http://*:0/", "http://127.0.0.1"); + dataset.Add("http://+:0/", "http://127.0.0.1"); + + // Non-loopback addresses + var ipv4Addresses = GetIPAddresses() + .Where(ip => ip.AddressFamily == AddressFamily.InterNetwork) + .Where(ip => CanBindAndConnectToEndpoint(new IPEndPoint(ip, 0))); + + foreach (var ip in ipv4Addresses) + { + dataset.Add($"http://{ip}:0/", $"http://{ip}"); + } + + return dataset; + } + } + + public static TheoryData AddressRegistrationDataIPv4Port5000Default => + new TheoryData + { + { null, "http://127.0.0.1:5000/" }, + { string.Empty, "http://127.0.0.1:5000/" } + }; + + public static TheoryData IPEndPointRegistrationDataDynamicPort + { + get + { + var dataset = new TheoryData(); + + // Loopback + dataset.Add(new IPEndPoint(IPAddress.Loopback, 0), "http://127.0.0.1"); + dataset.Add(new IPEndPoint(IPAddress.Loopback, 0), "https://127.0.0.1"); + + // IPv6 loopback + dataset.Add(new IPEndPoint(IPAddress.IPv6Loopback, 0), "http://[::1]"); + dataset.Add(new IPEndPoint(IPAddress.IPv6Loopback, 0), "https://[::1]"); + + // Any + dataset.Add(new IPEndPoint(IPAddress.Any, 0), "http://127.0.0.1"); + dataset.Add(new IPEndPoint(IPAddress.Any, 0), "https://127.0.0.1"); + + // IPv6 Any + dataset.Add(new IPEndPoint(IPAddress.IPv6Any, 0), "http://127.0.0.1"); + dataset.Add(new IPEndPoint(IPAddress.IPv6Any, 0), "http://[::1]"); + dataset.Add(new IPEndPoint(IPAddress.IPv6Any, 0), "https://127.0.0.1"); + dataset.Add(new IPEndPoint(IPAddress.IPv6Any, 0), "https://[::1]"); + + // Non-loopback addresses + var ipv4Addresses = GetIPAddresses() + .Where(ip => ip.AddressFamily == AddressFamily.InterNetwork) + .Where(ip => CanBindAndConnectToEndpoint(new IPEndPoint(ip, 0))); + + foreach (var ip in ipv4Addresses) + { + dataset.Add(new IPEndPoint(ip, 0), $"http://{ip}"); + dataset.Add(new IPEndPoint(ip, 0), $"https://{ip}"); + } + + var ipv6Addresses = GetIPAddresses() + .Where(ip => ip.AddressFamily == AddressFamily.InterNetworkV6) + .Where(ip => ip.ScopeId == 0) + .Where(ip => CanBindAndConnectToEndpoint(new IPEndPoint(ip, 0))); + + foreach (var ip in ipv6Addresses) + { + dataset.Add(new IPEndPoint(ip, 0), $"http://[{ip}]"); + } + + return dataset; + } + } + + public static TheoryData AddressRegistrationDataIPv4Port80 => + new TheoryData + { + // Default port for HTTP (80) + { "http://127.0.0.1", "http://127.0.0.1" }, + { "http://localhost", "http://127.0.0.1" }, + { "http://*", "http://127.0.0.1" } + }; + + public static TheoryData IPEndPointRegistrationDataPort443 => + new TheoryData + { + + { new IPEndPoint(IPAddress.Loopback, 443), "https://127.0.0.1" }, + { new IPEndPoint(IPAddress.IPv6Loopback, 443), "https://[::1]" }, + { new IPEndPoint(IPAddress.Any, 443), "https://127.0.0.1" }, + { new IPEndPoint(IPAddress.IPv6Any, 443), "https://[::1]" } + }; + + public static TheoryData AddressRegistrationDataIPv6 + { + get + { + var dataset = new TheoryData(); + + // Loopback + dataset.Add($"http://[::1]:0/", new[] { $"http://[::1]" }); + + // Any + dataset.Add($"http://*:0/", new[] { $"http://127.0.0.1", $"http://[::1]" }); + dataset.Add($"http://+:0/", new[] { $"http://127.0.0.1", $"http://[::1]" }); + + // Non-loopback addresses + var ipv6Addresses = GetIPAddresses() + .Where(ip => !ip.Equals(IPAddress.IPv6Loopback)) + .Where(ip => ip.AddressFamily == AddressFamily.InterNetworkV6) + .Where(ip => ip.ScopeId == 0) + .Where(ip => CanBindAndConnectToEndpoint(new IPEndPoint(ip, 0))); + + foreach (var ip in ipv6Addresses) + { + dataset.Add($"http://[{ip}]:0/", new[] { $"http://[{ip}]" }); + } + + return dataset; + } + } + + public static TheoryData AddressRegistrationDataIPv6Port5000Default => + new TheoryData + { + { null, new[] { "http://127.0.0.1:5000/", "http://[::1]:5000/" } }, + { string.Empty, new[] { "http://127.0.0.1:5000/", "http://[::1]:5000/" } } + }; + + public static TheoryData AddressRegistrationDataIPv6Port80 => + new TheoryData + { + // Default port for HTTP (80) + { "http://[::1]", new[] { "http://[::1]/" } }, + { "http://localhost", new[] { "http://127.0.0.1/", "http://[::1]/" } }, + { "http://*", new[] { "http://[::1]/" } } + }; + + public static TheoryData AddressRegistrationDataIPv6ScopeId + { + get + { + var dataset = new TheoryData(); + + var ipv6Addresses = GetIPAddresses() + .Where(ip => ip.AddressFamily == AddressFamily.InterNetworkV6) + .Where(ip => ip.ScopeId != 0) + .Where(ip => CanBindAndConnectToEndpoint(new IPEndPoint(ip, 0))); + + foreach (var ip in ipv6Addresses) + { + dataset.Add($"http://[{ip}]:0/", $"http://[{ip}]"); + } + + // There may be no addresses with scope IDs and we need at least one data item in the + // collection, otherwise xUnit fails the test run because a theory has no data. + dataset.Add("http://[::1]:0", "http://[::1]"); + + return dataset; + } + } + + private static IEnumerable GetIPAddresses() + { + return NetworkInterface.GetAllNetworkInterfaces() + .Where(i => i.OperationalStatus == OperationalStatus.Up) + .SelectMany(i => i.GetIPProperties().UnicastAddresses) + .Select(a => a.Address); + } + + private void ConfigureEchoAddress(IApplicationBuilder app) + { + app.Run(context => + { + return context.Response.WriteAsync(context.Request.GetDisplayUrl()); + }); + } + + private static int GetNextPort() + { + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + // Let the OS assign the next available port. Unless we cycle through all ports + // on a test run, the OS will always increment the port number when making these calls. + // This prevents races in parallel test runs where a test is already bound to + // a given port, and a new test is able to bind to the same port due to port + // reuse being enabled by default by the OS. + socket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)socket.LocalEndPoint).Port; + } + } + + private static bool CanBindAndConnectToEndpoint(IPEndPoint endPoint) + { + try + { + using (var serverSocket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp)) + { + serverSocket.Bind(endPoint); + serverSocket.Listen(0); + + var socketArgs = new SocketAsyncEventArgs + { + RemoteEndPoint = serverSocket.LocalEndPoint + }; + + var mre = new ManualResetEventSlim(); + socketArgs.Completed += (s, e) => + { + mre.Set(); + e.ConnectSocket?.Dispose(); + }; + + // Connect can take a couple minutes to time out. + if (Socket.ConnectAsync(SocketType.Stream, ProtocolType.Tcp, socketArgs)) + { + return mre.Wait(5000) && socketArgs.SocketError == SocketError.Success; + } + else + { + socketArgs.ConnectSocket?.Dispose(); + return socketArgs.SocketError == SocketError.Success; + } + } + } + catch (SocketException) + { + return false; + } + } + + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + private class PortSupportedConditionAttribute : Attribute, ITestCondition + { + private readonly int _port; + private readonly Lazy _portSupported; + + public PortSupportedConditionAttribute(int port) + { + _port = port; + _portSupported = new Lazy(CanBindToPort); + } + + public bool IsMet => _portSupported.Value; + + public string SkipReason => $"Cannot bind to port {_port} on the host."; + + private bool CanBindToPort() + { + try + { + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Bind(new IPEndPoint(IPAddress.Loopback, _port)); + socket.Listen(0); + return true; + } + } + catch (SocketException) + { + return false; + } + } + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/BadHttpRequestTests.cs b/src/Servers/Kestrel/test/FunctionalTests/BadHttpRequestTests.cs new file mode 100644 index 0000000000..2954b3f7c2 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/BadHttpRequestTests.cs @@ -0,0 +1,262 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class BadHttpRequestTests : LoggedTest + { + [Theory] + [MemberData(nameof(InvalidRequestLineData))] + public Task TestInvalidRequestLines(string request, string expectedExceptionMessage) + { + return TestBadRequest( + request, + "400 Bad Request", + expectedExceptionMessage); + } + + [Theory] + [MemberData(nameof(UnrecognizedHttpVersionData))] + public Task TestInvalidRequestLinesWithUnrecognizedVersion(string httpVersion) + { + return TestBadRequest( + $"GET / {httpVersion}\r\n", + "505 HTTP Version Not Supported", + CoreStrings.FormatBadRequest_UnrecognizedHTTPVersion(httpVersion)); + } + + [Theory] + [MemberData(nameof(InvalidRequestHeaderData))] + public Task TestInvalidHeaders(string rawHeaders, string expectedExceptionMessage) + { + return TestBadRequest( + $"GET / HTTP/1.1\r\n{rawHeaders}", + "400 Bad Request", + expectedExceptionMessage); + } + + [Theory] + [InlineData("Hea\0der: value", "Invalid characters in header name.")] + [InlineData("Header: va\0lue", "Malformed request: invalid headers.")] + [InlineData("Head\x80r: value", "Invalid characters in header name.")] + [InlineData("Header: valu\x80", "Malformed request: invalid headers.")] + public Task BadRequestWhenHeaderNameContainsNonASCIIOrNullCharacters(string header, string expectedExceptionMessage) + { + return TestBadRequest( + $"GET / HTTP/1.1\r\n{header}\r\n\r\n", + "400 Bad Request", + expectedExceptionMessage); + } + + [Theory] + [InlineData("POST")] + [InlineData("PUT")] + public Task BadRequestIfMethodRequiresLengthButNoContentLengthOrTransferEncodingInRequest(string method) + { + return TestBadRequest( + $"{method} / HTTP/1.1\r\nHost:\r\n\r\n", + "411 Length Required", + CoreStrings.FormatBadRequest_LengthRequired(method)); + } + + [Theory] + [InlineData("POST")] + [InlineData("PUT")] + public Task BadRequestIfMethodRequiresLengthButNoContentLengthInHttp10Request(string method) + { + return TestBadRequest( + $"{method} / HTTP/1.0\r\n\r\n", + "400 Bad Request", + CoreStrings.FormatBadRequest_LengthRequiredHttp10(method)); + } + + [Theory] + [InlineData("NaN")] + [InlineData("-1")] + public Task BadRequestIfContentLengthInvalid(string contentLength) + { + return TestBadRequest( + $"POST / HTTP/1.1\r\nHost:\r\nContent-Length: {contentLength}\r\n\r\n", + "400 Bad Request", + CoreStrings.FormatBadRequest_InvalidContentLength_Detail(contentLength)); + } + + [Theory] + [InlineData("GET *", "OPTIONS")] + [InlineData("GET www.host.com", "CONNECT")] + public Task RejectsIncorrectMethods(string request, string allowedMethod) + { + return TestBadRequest( + $"{request} HTTP/1.1\r\n", + "405 Method Not Allowed", + CoreStrings.BadRequest_MethodNotAllowed, + $"Allow: {allowedMethod}"); + } + + [Fact] + public Task BadRequestIfHostHeaderMissing() + { + return TestBadRequest( + "GET / HTTP/1.1\r\n\r\n", + "400 Bad Request", + CoreStrings.BadRequest_MissingHostHeader); + } + + [Fact] + public Task BadRequestIfMultipleHostHeaders() + { + return TestBadRequest("GET / HTTP/1.1\r\nHost: localhost\r\nHost: localhost\r\n\r\n", + "400 Bad Request", + CoreStrings.BadRequest_MultipleHostHeaders); + } + + [Theory] + [MemberData(nameof(InvalidHostHeaderData))] + public Task BadRequestIfHostHeaderDoesNotMatchRequestTarget(string requestTarget, string host) + { + return TestBadRequest( + $"{requestTarget} HTTP/1.1\r\nHost: {host}\r\n\r\n", + "400 Bad Request", + CoreStrings.FormatBadRequest_InvalidHostHeader_Detail(host.Trim())); + } + + [Fact] + public Task BadRequestFor10BadHostHeaderFormat() + { + return TestBadRequest( + $"GET / HTTP/1.0\r\nHost: a=b\r\n\r\n", + "400 Bad Request", + CoreStrings.FormatBadRequest_InvalidHostHeader_Detail("a=b")); + } + + [Fact] + public Task BadRequestFor11BadHostHeaderFormat() + { + return TestBadRequest( + $"GET / HTTP/1.1\r\nHost: a=b\r\n\r\n", + "400 Bad Request", + CoreStrings.FormatBadRequest_InvalidHostHeader_Detail("a=b")); + } + + [Fact] + public async Task BadRequestLogsAreNotHigherThanInformation() + { + using (var server = new TestServer(async context => + { + await context.Request.Body.ReadAsync(new byte[1], 0, 1); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = new TestConnection(server.Port)) + { + await connection.SendAll( + "GET ? HTTP/1.1", + "", + ""); + await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); + } + } + + Assert.All(TestSink.Writes, w => Assert.InRange(w.LogLevel, LogLevel.Trace, LogLevel.Information)); + Assert.Contains(TestSink.Writes, w => w.EventId.Id == 17 && w.LogLevel == LogLevel.Information); + } + + [Fact] + public async Task TestRequestSplitting() + { + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory, Mock.Of()))) + { + using (var client = server.CreateConnection()) + { + await client.SendAll( + "GET /\x0D\0x0ALocation:http://www.contoso.com/ HTTP/1.1", + "Host:\r\n\r\n"); + + await client.ReceiveStartsWith("HTTP/1.1 400"); + } + } + } + + private async Task TestBadRequest(string request, string expectedResponseStatusCode, string expectedExceptionMessage, string expectedAllowHeader = null) + { + BadHttpRequestException loggedException = null; + var mockKestrelTrace = new Mock(); + mockKestrelTrace + .Setup(trace => trace.IsEnabled(LogLevel.Information)) + .Returns(true); + mockKestrelTrace + .Setup(trace => trace.ConnectionBadRequest(It.IsAny(), It.IsAny())) + .Callback((connectionId, exception) => loggedException = exception); + + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory, mockKestrelTrace.Object))) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll(request); + await ReceiveBadRequestResponse(connection, expectedResponseStatusCode, server.Context.DateHeaderValue, expectedAllowHeader); + } + } + + mockKestrelTrace.Verify(trace => trace.ConnectionBadRequest(It.IsAny(), It.IsAny())); + Assert.Equal(expectedExceptionMessage, loggedException.Message); + } + + private async Task ReceiveBadRequestResponse(TestConnection connection, string expectedResponseStatusCode, string expectedDateHeaderValue, string expectedAllowHeader = null) + { + var lines = new[] + { + $"HTTP/1.1 {expectedResponseStatusCode}", + "Connection: close", + $"Date: {expectedDateHeaderValue}", + "Content-Length: 0", + expectedAllowHeader, + "", + "" + }; + + await connection.ReceiveForcedEnd(lines.Where(f => f != null).ToArray()); + } + + public static TheoryData InvalidRequestLineData + { + get + { + var data = new TheoryData(); + + foreach (var requestLine in HttpParsingData.RequestLineInvalidData) + { + data.Add(requestLine, CoreStrings.FormatBadRequest_InvalidRequestLine_Detail(requestLine.EscapeNonPrintable())); + } + + foreach (var target in HttpParsingData.TargetWithEncodedNullCharData) + { + data.Add($"GET {target} HTTP/1.1\r\n", CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(target.EscapeNonPrintable())); + } + + foreach (var target in HttpParsingData.TargetWithNullCharData) + { + data.Add($"GET {target} HTTP/1.1\r\n", CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(target.EscapeNonPrintable())); + } + + return data; + } + } + + public static TheoryData UnrecognizedHttpVersionData => HttpParsingData.UnrecognizedHttpVersionData; + + public static IEnumerable InvalidRequestHeaderData => HttpParsingData.RequestHeaderInvalidData; + + public static TheoryData InvalidHostHeaderData => HttpParsingData.HostHeaderInvalidData; + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/CertificateLoaderTests.cs b/src/Servers/Kestrel/test/FunctionalTests/CertificateLoaderTests.cs new file mode 100644 index 0000000000..86cbaa288d --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/CertificateLoaderTests.cs @@ -0,0 +1,57 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Linq; +using System.Security.Cryptography.X509Certificates; +using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class CertificateLoaderTests : LoggedTest + { + [Theory] + [InlineData("no_extensions.pfx")] + public void IsCertificateAllowedForServerAuth_AllowWithNoExtensions(string testCertName) + { + var certPath = TestResources.GetCertPath(testCertName); + TestOutputHelper.WriteLine("Loading " + certPath); + var cert = new X509Certificate2(certPath, "testPassword"); + Assert.Empty(cert.Extensions.OfType()); + + Assert.True(CertificateLoader.IsCertificateAllowedForServerAuth(cert)); + } + + [Theory] + [InlineData("eku.server.pfx")] + [InlineData("eku.multiple_usages.pfx")] + public void IsCertificateAllowedForServerAuth_ValidatesEnhancedKeyUsageOnCertificate(string testCertName) + { + var certPath = TestResources.GetCertPath(testCertName); + TestOutputHelper.WriteLine("Loading " + certPath); + var cert = new X509Certificate2(certPath, "testPassword"); + Assert.NotEmpty(cert.Extensions); + var eku = Assert.Single(cert.Extensions.OfType()); + Assert.NotEmpty(eku.EnhancedKeyUsages); + + Assert.True(CertificateLoader.IsCertificateAllowedForServerAuth(cert)); + } + + [Theory] + [InlineData("eku.code_signing.pfx")] + [InlineData("eku.client.pfx")] + public void IsCertificateAllowedForServerAuth_RejectsCertificatesMissingServerEku(string testCertName) + { + var certPath = TestResources.GetCertPath(testCertName); + TestOutputHelper.WriteLine("Loading " + certPath); + var cert = new X509Certificate2(certPath, "testPassword"); + Assert.NotEmpty(cert.Extensions); + var eku = Assert.Single(cert.Extensions.OfType()); + Assert.NotEmpty(eku.EnhancedKeyUsages); + + Assert.False(CertificateLoader.IsCertificateAllowedForServerAuth(cert)); + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/ChunkedRequestTests.cs b/src/Servers/Kestrel/test/FunctionalTests/ChunkedRequestTests.cs new file mode 100644 index 0000000000..ecbb6f23b2 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/ChunkedRequestTests.cs @@ -0,0 +1,690 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class ChunkedRequestTests : LoggedTest + { + public static TheoryData ConnectionAdapterData => new TheoryData + { + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)), + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = { new PassThroughConnectionAdapter() } + } + }; + + private async Task App(HttpContext httpContext) + { + var request = httpContext.Request; + var response = httpContext.Response; + while (true) + { + var buffer = new byte[8192]; + var count = await request.Body.ReadAsync(buffer, 0, buffer.Length); + if (count == 0) + { + break; + } + await response.Body.WriteAsync(buffer, 0, count); + } + } + + private async Task AppChunked(HttpContext httpContext) + { + var request = httpContext.Request; + var response = httpContext.Response; + var data = new MemoryStream(); + await request.Body.CopyToAsync(data); + var bytes = data.ToArray(); + + response.Headers["Content-Length"] = bytes.Length.ToString(); + await response.Body.WriteAsync(bytes, 0, bytes.Length); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task Http10TransferEncoding(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(App, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.0", + "Host:", + "Transfer-Encoding: chunked", + "", + "5", "Hello", + "6", " World", + "0", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "", + "Hello World"); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task Http10KeepAliveTransferEncoding(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(AppChunked, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.0", + "Host:", + "Transfer-Encoding: chunked", + "Connection: keep-alive", + "", + "5", "Hello", + "6", " World", + "0", + "", + "POST / HTTP/1.0", + "Content-Length: 7", + "", + "Goodbye"); + await connection.Receive( + "HTTP/1.1 200 OK", + "Connection: keep-alive", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 7", + "", + "Goodbye"); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task RequestBodyIsConsumedAutomaticallyIfAppDoesntConsumeItFully(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + + Assert.Equal("POST", request.Method); + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "HelloPOST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "C", "HelloChunked", + "0", + "", + "POST / HTTP/1.1", + "Host:", + "Content-Length: 7", + "", + "Goodbye"); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello WorldHTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello WorldHTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task TrailingHeadersAreParsed(ListenOptions listenOptions) + { + var requestCount = 10; + var requestsReceived = 0; + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + + var buffer = new byte[200]; + + while (await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0) + { + ;// read to end + } + + if (requestsReceived < requestCount) + { + Assert.Equal(new string('a', requestsReceived), request.Headers["X-Trailer-Header"].ToString()); + } + else + { + Assert.True(string.IsNullOrEmpty(request.Headers["X-Trailer-Header"])); + } + + requestsReceived++; + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11); + }, new TestServiceContext(LoggerFactory), listenOptions)) + { + var response = string.Join("\r\n", new string[] { + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"}); + + var expectedFullResponse = string.Join("", Enumerable.Repeat(response, requestCount + 1)); + + IEnumerable sendSequence = new string[] { + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "C", + "HelloChunked", + "0", + ""}; + + for (var i = 1; i < requestCount; i++) + { + sendSequence = sendSequence.Concat(new string[] { + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "C", + $"HelloChunk{i:00}", + "0", + string.Concat("X-Trailer-Header: ", new string('a', i)), + "" }); + } + + sendSequence = sendSequence.Concat(new string[] { + "POST / HTTP/1.1", + "Host:", + "Content-Length: 7", + "", + "Goodbye" + }); + + var fullRequest = sendSequence.ToArray(); + + using (var connection = server.CreateConnection()) + { + await connection.Send(fullRequest); + await connection.ReceiveEnd(expectedFullResponse); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task TrailingHeadersCountTowardsHeadersTotalSizeLimit(ListenOptions listenOptions) + { + const string transferEncodingHeaderLine = "Transfer-Encoding: chunked"; + const string headerLine = "Header: value"; + const string trailingHeaderLine = "Trailing-Header: trailing-value"; + + var testContext = new TestServiceContext(LoggerFactory); + testContext.ServerOptions.Limits.MaxRequestHeadersTotalSize = + transferEncodingHeaderLine.Length + 2 + + headerLine.Length + 2 + + trailingHeaderLine.Length + 1; + + using (var server = new TestServer(async context => + { + var buffer = new byte[128]; + while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) != 0) ; // read to end + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.1", + "Host:", + $"{transferEncodingHeaderLine}", + $"{headerLine}", + "", + "2", + "42", + "0", + $"{trailingHeaderLine}", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 431 Request Header Fields Too Large", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task TrailingHeadersCountTowardsHeaderCountLimit(ListenOptions listenOptions) + { + const string transferEncodingHeaderLine = "Transfer-Encoding: chunked"; + const string headerLine = "Header: value"; + const string trailingHeaderLine = "Trailing-Header: trailing-value"; + + var testContext = new TestServiceContext(LoggerFactory); + testContext.ServerOptions.Limits.MaxRequestHeaderCount = 2; + + using (var server = new TestServer(async context => + { + var buffer = new byte[128]; + while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) != 0) ; // read to end + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.1", + "Host:", + $"{transferEncodingHeaderLine}", + $"{headerLine}", + "", + "2", + "42", + "0", + $"{trailingHeaderLine}", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 431 Request Header Fields Too Large", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ExtensionsAreIgnored(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + var requestCount = 10; + var requestsReceived = 0; + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + + var buffer = new byte[200]; + + while (await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0) + { + ;// read to end + } + + if (requestsReceived < requestCount) + { + Assert.Equal(new string('a', requestsReceived), request.Headers["X-Trailer-Header"].ToString()); + } + else + { + Assert.True(string.IsNullOrEmpty(request.Headers["X-Trailer-Header"])); + } + + requestsReceived++; + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11); + }, testContext, listenOptions)) + { + var response = string.Join("\r\n", new string[] { + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"}); + + var expectedFullResponse = string.Join("", Enumerable.Repeat(response, requestCount + 1)); + + IEnumerable sendSequence = new string[] { + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "C;hello there", + "HelloChunked", + "0;hello there", + ""}; + + for (var i = 1; i < requestCount; i++) + { + sendSequence = sendSequence.Concat(new string[] { + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "C;hello there", + $"HelloChunk{i:00}", + "0;hello there", + string.Concat("X-Trailer-Header: ", new string('a', i)), + "" }); + } + + sendSequence = sendSequence.Concat(new string[] { + "POST / HTTP/1.1", + "Host:", + "Content-Length: 7", + "", + "Goodbye" + }); + + var fullRequest = sendSequence.ToArray(); + + using (var connection = server.CreateConnection()) + { + await connection.Send(fullRequest); + await connection.ReceiveEnd(expectedFullResponse); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task InvalidLengthResultsIn400(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + + var buffer = new byte[200]; + + while (await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0) + { + ;// read to end + } + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "Cii"); + + await connection.Receive( + "HTTP/1.1 400 Bad Request", + "Connection: close", + ""); + await connection.ReceiveForcedEnd( + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task InvalidSizedDataResultsIn400(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + + var buffer = new byte[200]; + + while (await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0) + { + ;// read to end + } + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "C", + "HelloChunkedIn"); + + await connection.Receive( + "HTTP/1.1 400 Bad Request", + "Connection: close", + ""); + await connection.ReceiveForcedEnd( + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ChunkedNotFinalTransferCodingResultsIn400(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + using (var server = new TestServer(httpContext => + { + return Task.CompletedTask; + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: not-chunked", + "", + "C", + "hello, world", + "0", + "", + ""); + + await connection.ReceiveForcedEnd( + "HTTP/1.1 400 Bad Request", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + // Content-Length should not affect this + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: not-chunked", + "Content-Length: 22", + "", + "C", + "hello, world", + "0", + "", + ""); + + await connection.ReceiveForcedEnd( + "HTTP/1.1 400 Bad Request", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked, not-chunked", + "", + "C", + "hello, world", + "0", + "", + ""); + + await connection.ReceiveForcedEnd( + "HTTP/1.1 400 Bad Request", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + // Content-Length should not affect this + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked, not-chunked", + "Content-Length: 22", + "", + "C", + "hello, world", + "0", + "", + ""); + + await connection.ReceiveForcedEnd( + "HTTP/1.1 400 Bad Request", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ClosingConnectionMidChunkPrefixThrows(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + var readStartedTcs = new TaskCompletionSource(); + var exTcs = new TaskCompletionSource(); + + using (var server = new TestServer(async httpContext => + { + var readTask = httpContext.Request.Body.CopyToAsync(Stream.Null); + readStartedTcs.SetResult(null); + + try + { + await readTask; + } + catch (BadHttpRequestException badRequestEx) + { + exTcs.TrySetResult(badRequestEx); + } + catch (Exception ex) + { + exTcs.SetException(ex); + } + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "1"); + + await readStartedTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout); + + connection.Socket.Shutdown(SocketShutdown.Send); + + await connection.ReceiveEnd(ignoreResponse: true); + + var badReqEx = await exTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout); + Assert.Equal(RequestRejectionReason.UnexpectedEndOfRequestContent, badReqEx.Reason); + } + } + } + } +} + diff --git a/src/Servers/Kestrel/test/FunctionalTests/ChunkedResponseTests.cs b/src/Servers/Kestrel/test/FunctionalTests/ChunkedResponseTests.cs new file mode 100644 index 0000000000..6cb1a9e439 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/ChunkedResponseTests.cs @@ -0,0 +1,434 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class ChunkedResponseTests : LoggedTest + { + public static TheoryData ConnectionAdapterData => new TheoryData + { + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)), + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = { new PassThroughConnectionAdapter() } + } + }; + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ResponsesAreChunkedAutomatically(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6); + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "6", + "Hello ", + "6", + "World!", + "0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ResponsesAreNotChunkedAutomaticallyForHttp10Requests(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + await httpContext.Response.WriteAsync("Hello "); + await httpContext.Response.WriteAsync("World!"); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.0", + "Connection: keep-alive", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "", + "Hello World!"); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ResponsesAreChunkedAutomaticallyForHttp11NonKeepAliveRequests(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + await httpContext.Response.WriteAsync("Hello "); + await httpContext.Response.WriteAsync("World!"); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "Connection: close", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "6", + "Hello ", + "6", + "World!", + "0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task SettingConnectionCloseHeaderInAppDoesNotDisableChunking(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.Headers["Connection"] = "close"; + await httpContext.Response.WriteAsync("Hello "); + await httpContext.Response.WriteAsync("World!"); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "6", + "Hello ", + "6", + "World!", + "0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ZeroLengthWritesAreIgnored(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6); + await response.Body.WriteAsync(new byte[0], 0, 0); + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "6", + "Hello ", + "6", + "World!", + "0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ZeroLengthWritesFlushHeaders(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + var flushed = new SemaphoreSlim(0, 1); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + await response.WriteAsync(""); + + await flushed.WaitAsync(); + + await response.WriteAsync("Hello World!"); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + ""); + + flushed.Release(); + + await connection.ReceiveEnd( + "c", + "Hello World!", + "0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task EmptyResponseBodyHandledCorrectlyWithZeroLengthWrite(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + await response.Body.WriteAsync(new byte[0], 0, 0); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ConnectionClosedIfExceptionThrownAfterWrite(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World!"), 0, 12); + throw new Exception(); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + // SendEnd is not called, so it isn't the client closing the connection. + // client closing the connection. + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "c", + "Hello World!", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ConnectionClosedIfExceptionThrownAfterZeroLengthWrite(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + await response.Body.WriteAsync(new byte[0], 0, 0); + throw new Exception(); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + // SendEnd is not called, so it isn't the client closing the connection. + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + + // Headers are sent before connection is closed, but chunked body terminator isn't sent + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task WritesAreFlushedPriorToResponseCompletion(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + var flushWh = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6); + + // Don't complete response until client has received the first chunk. + await flushWh.Task.DefaultTimeout(); + + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "6", + "Hello ", + ""); + + flushWh.SetResult(null); + + await connection.ReceiveEnd( + "6", + "World!", + "0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ChunksCanBeWrittenManually(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + response.Headers["Transfer-Encoding"] = "chunked"; + + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("6\r\nHello \r\n"), 0, 11); + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("6\r\nWorld!\r\n"), 0, 11); + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("0\r\n\r\n"), 0, 5); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "6", + "Hello ", + "6", + "World!", + "0", + "", + ""); + } + } + } + } +} + diff --git a/src/Servers/Kestrel/test/FunctionalTests/ConnectionAdapterTests.cs b/src/Servers/Kestrel/test/FunctionalTests/ConnectionAdapterTests.cs new file mode 100644 index 0000000000..a4c719bbde --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/ConnectionAdapterTests.cs @@ -0,0 +1,356 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Net; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class ConnectionAdapterTests : LoggedTest + { + [Fact] + public async Task CanReadAndWriteWithRewritingConnectionAdapter() + { + var adapter = new RewritingConnectionAdapter(); + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = { adapter } + }; + + var serviceContext = new TestServiceContext(LoggerFactory); + + var sendString = "POST / HTTP/1.0\r\nContent-Length: 12\r\n\r\nHello World?"; + + using (var server = new TestServer(TestApp.EchoApp, serviceContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + // "?" changes to "!" + await connection.Send(sendString); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {serviceContext.DateHeaderValue}", + "", + "Hello World!"); + } + } + + Assert.Equal(sendString.Length, adapter.BytesRead); + } + + [Fact] + public async Task CanReadAndWriteWithAsyncConnectionAdapter() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = { new AsyncConnectionAdapter() } + }; + + var serviceContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EchoApp, serviceContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.0", + "Content-Length: 12", + "", + "Hello World?"); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {serviceContext.DateHeaderValue}", + "", + "Hello World!"); + } + } + } + + [Fact] + public async Task ImmediateFinAfterOnConnectionAsyncClosesGracefully() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = { new AsyncConnectionAdapter() } + }; + + var serviceContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EchoApp, serviceContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + // FIN + connection.Shutdown(SocketShutdown.Send); + await connection.WaitForConnectionClose(); + } + } + } + + [Fact] + public async Task ImmediateFinAfterThrowingClosesGracefully() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = { new ThrowingConnectionAdapter() } + }; + + var serviceContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EchoApp, serviceContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + // FIN + connection.Shutdown(SocketShutdown.Send); + await connection.WaitForConnectionClose(); + } + } + } + + [Fact] + public async Task ImmediateShutdownAfterOnConnectionAsyncDoesNotCrash() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = { new AsyncConnectionAdapter() } + }; + + var serviceContext = new TestServiceContext(LoggerFactory); + + var stopTask = Task.CompletedTask; + using (var server = new TestServer(TestApp.EchoApp, serviceContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + stopTask = server.StopAsync(); + } + + await stopTask; + } + } + + [Fact] + public async Task ThrowingSynchronousConnectionAdapterDoesNotCrashServer() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = { new ThrowingConnectionAdapter() } + }; + + var serviceContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EchoApp, serviceContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + // Will throw because the exception in the connection adapter will close the connection. + await Assert.ThrowsAsync(async () => + { + await connection.Send( + "POST / HTTP/1.0", + "Content-Length: 1000", + "\r\n"); + + for (var i = 0; i < 1000; i++) + { + await connection.Send("a"); + await Task.Delay(5); + } + }); + } + } + } + + [Fact] + public async Task CanFlushAsyncWithConnectionAdapter() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = { new PassThroughConnectionAdapter() } + }; + + var serviceContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async context => + { + await context.Response.WriteAsync("Hello "); + await context.Response.Body.FlushAsync(); + await context.Response.WriteAsync("World!"); + }, serviceContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.0", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {serviceContext.DateHeaderValue}", + "", + "Hello World!"); + } + } + } + + private class RewritingConnectionAdapter : IConnectionAdapter + { + private RewritingStream _rewritingStream; + + public bool IsHttps => false; + + public Task OnConnectionAsync(ConnectionAdapterContext context) + { + _rewritingStream = new RewritingStream(context.ConnectionStream); + return Task.FromResult(new AdaptedConnection(_rewritingStream)); + } + + public int BytesRead => _rewritingStream.BytesRead; + } + + private class AsyncConnectionAdapter : IConnectionAdapter + { + public bool IsHttps => false; + + public async Task OnConnectionAsync(ConnectionAdapterContext context) + { + await Task.Delay(100); + return new AdaptedConnection(new RewritingStream(context.ConnectionStream)); + } + } + + private class ThrowingConnectionAdapter : IConnectionAdapter + { + public bool IsHttps => false; + + public Task OnConnectionAsync(ConnectionAdapterContext context) + { + throw new Exception(); + } + } + + private class AdaptedConnection : IAdaptedConnection + { + public AdaptedConnection(Stream adaptedStream) + { + ConnectionStream = adaptedStream; + } + + public Stream ConnectionStream { get; } + + public void Dispose() + { + } + } + + private class RewritingStream : Stream + { + private readonly Stream _innerStream; + + public RewritingStream(Stream innerStream) + { + _innerStream = innerStream; + } + + public int BytesRead { get; private set; } + + public override bool CanRead => _innerStream.CanRead; + + public override bool CanSeek => _innerStream.CanSeek; + + public override bool CanWrite => _innerStream.CanWrite; + + public override long Length => _innerStream.Length; + + public override long Position + { + get + { + return _innerStream.Position; + } + set + { + _innerStream.Position = value; + } + } + + public override void Flush() + { + _innerStream.Flush(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _innerStream.FlushAsync(cancellationToken); + } + + public override int Read(byte[] buffer, int offset, int count) + { + var actual = _innerStream.Read(buffer, offset, count); + + BytesRead += actual; + + return actual; + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var actual = await _innerStream.ReadAsync(buffer, offset, count); + + BytesRead += actual; + + return actual; + } + + public override long Seek(long offset, SeekOrigin origin) + { + return _innerStream.Seek(offset, origin); + } + + public override void SetLength(long value) + { + _innerStream.SetLength(value); + } + + public override void Write(byte[] buffer, int offset, int count) + { + for (int i = 0; i < buffer.Length; i++) + { + if (buffer[i] == '?') + { + buffer[i] = (byte)'!'; + } + } + + _innerStream.Write(buffer, offset, count); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + for (int i = 0; i < buffer.Length; i++) + { + if (buffer[i] == '?') + { + buffer[i] = (byte)'!'; + } + } + + return _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + } + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/ConnectionLimitTests.cs b/src/Servers/Kestrel/test/FunctionalTests/ConnectionLimitTests.cs new file mode 100644 index 0000000000..df8a3b51e0 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/ConnectionLimitTests.cs @@ -0,0 +1,212 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Tests; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class ConnectionLimitTests : LoggedTest + { + [Fact] + public async Task ResetsCountWhenConnectionClosed() + { + var requestTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var releasedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var lockedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var counter = new EventRaisingResourceCounter(ResourceCounter.Quota(1)); + counter.OnLock += (s, e) => lockedTcs.TrySetResult(e); + counter.OnRelease += (s, e) => releasedTcs.TrySetResult(null); + + using (var server = CreateServerWithMaxConnections(async context => + { + await context.Response.WriteAsync("Hello"); + await requestTcs.Task; + }, counter)) + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGetAsKeepAlive(); ; + await connection.Receive("HTTP/1.1 200 OK"); + Assert.True(await lockedTcs.Task.DefaultTimeout()); + requestTcs.TrySetResult(null); + } + + await releasedTcs.Task.DefaultTimeout(); + } + + [Fact] + public async Task UpgradedConnectionsCountsAgainstDifferentLimit() + { + using (var server = CreateServerWithMaxConnections(async context => + { + var feature = context.Features.Get(); + if (feature.IsUpgradableRequest) + { + var stream = await feature.UpgradeAsync(); + // keep it running until aborted + while (!context.RequestAborted.IsCancellationRequested) + { + await Task.Delay(100); + } + } + }, max: 1)) + using (var disposables = new DisposableStack()) + { + var upgraded = server.CreateConnection(); + disposables.Push(upgraded); + + await upgraded.SendEmptyGetWithUpgrade(); + await upgraded.Receive("HTTP/1.1 101"); + // once upgraded, normal connection limit is decreased to allow room for more "normal" connections + + var connection = server.CreateConnection(); + disposables.Push(connection); + + await connection.SendEmptyGetAsKeepAlive(); + await connection.Receive("HTTP/1.1 200 OK"); + + using (var rejected = server.CreateConnection()) + { + try + { + // this may throw IOException, depending on how fast Kestrel closes the socket + await rejected.SendEmptyGetAsKeepAlive(); + } + catch { } + + // connection should close without sending any data + await rejected.WaitForConnectionClose().DefaultTimeout(); + } + } + } + + [Fact] + public async Task RejectsConnectionsWhenLimitReached() + { + const int max = 10; + var requestTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = CreateServerWithMaxConnections(async context => + { + await context.Response.WriteAsync("Hello"); + await requestTcs.Task; + }, max)) + using (var disposables = new DisposableStack()) + { + for (var i = 0; i < max; i++) + { + var connection = server.CreateConnection(); + disposables.Push(connection); + + await connection.SendEmptyGetAsKeepAlive(); + await connection.Receive("HTTP/1.1 200 OK"); + } + + // limit has been reached + for (var i = 0; i < 10; i++) + { + using (var connection = server.CreateConnection()) + { + try + { + // this may throw IOException, depending on how fast Kestrel closes the socket + await connection.SendEmptyGetAsKeepAlive(); + } + catch { } + + // connection should close without sending any data + await connection.WaitForConnectionClose().DefaultTimeout(); + } + } + + requestTcs.TrySetResult(null); + } + } + + [Fact(Skip = "https://github.com/aspnet/KestrelHttpServer/issues/2282")] + public async Task ConnectionCountingReturnsToZero() + { + const int count = 100; + var opened = 0; + var closed = 0; + var openedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var counter = new EventRaisingResourceCounter(ResourceCounter.Quota(uint.MaxValue)); + + counter.OnLock += (o, e) => + { + if (e && Interlocked.Increment(ref opened) >= count) + { + openedTcs.TrySetResult(null); + } + }; + + counter.OnRelease += (o, e) => + { + if (Interlocked.Increment(ref closed) >= count) + { + closedTcs.TrySetResult(null); + } + }; + + using (var server = CreateServerWithMaxConnections(_ => Task.CompletedTask, counter)) + { + // open a bunch of connections in parallel + Parallel.For(0, count, async i => + { + try + { + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGetAsKeepAlive(); + await connection.Receive("HTTP/1.1 200"); + } + } + catch (Exception ex) + { + openedTcs.TrySetException(ex); + } + }); + + // wait until resource counter has called lock for each connection + await openedTcs.Task.TimeoutAfter(TimeSpan.FromSeconds(120)); + // wait until resource counter has released all normal connections + await closedTcs.Task.TimeoutAfter(TimeSpan.FromSeconds(120)); + Assert.Equal(count, opened); + Assert.Equal(count, closed); + } + } + + private TestServer CreateServerWithMaxConnections(RequestDelegate app, long max) + { + var serviceContext = new TestServiceContext(LoggerFactory); + serviceContext.ServerOptions.Limits.MaxConcurrentConnections = max; + return new TestServer(app, serviceContext); + } + + private TestServer CreateServerWithMaxConnections(RequestDelegate app, ResourceCounter concurrentConnectionCounter) + { + var serviceContext = new TestServiceContext(LoggerFactory); + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + listenOptions.Use(next => + { + var middleware = new ConnectionLimitMiddleware(next, concurrentConnectionCounter, serviceContext.Log); + return middleware.OnConnectionAsync; + }); + + return new TestServer(app, serviceContext, listenOptions); + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/DefaultHeaderTests.cs b/src/Servers/Kestrel/test/FunctionalTests/DefaultHeaderTests.cs new file mode 100644 index 0000000000..c9a564c68f --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/DefaultHeaderTests.cs @@ -0,0 +1,50 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class DefaultHeaderTests : LoggedTest + { + [Fact] + public async Task TestDefaultHeaders() + { + var testContext = new TestServiceContext(LoggerFactory) + { + ServerOptions = { AddServerHeader = true } + }; + + using (var server = new TestServer(ctx => Task.CompletedTask, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + "GET / HTTP/1.0", + "", + ""); + + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Server: Kestrel", + "Content-Length: 0", + "", + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Server: Kestrel", + "Content-Length: 0", + "", + ""); + } + } + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/EventSourceTests.cs b/src/Servers/Kestrel/test/FunctionalTests/EventSourceTests.cs new file mode 100644 index 0000000000..b10b81d128 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/EventSourceTests.cs @@ -0,0 +1,112 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics.Tracing; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class EventSourceTests : LoggedTest, IDisposable + { + private readonly TestEventListener _listener = new TestEventListener(); + + public EventSourceTests() + { + _listener.EnableEvents(KestrelEventSource.Log, EventLevel.Verbose); + } + + [Fact] + public async Task EmitsConnectionStartAndStop() + { + string connectionId = null; + string requestId = null; + int port; + using (var server = new TestServer(context => + { + connectionId = context.Features.Get().ConnectionId; + requestId = context.TraceIdentifier; + return Task.CompletedTask; + }, new TestServiceContext(LoggerFactory))) + { + port = server.Port; + using (var connection = server.CreateConnection()) + { + await connection.SendAll("GET / HTTP/1.1", + "Host:", + "", + "") + .DefaultTimeout(); + await connection.Receive("HTTP/1.1 200"); + } + } + + // capture list here as other tests executing in parallel may log events + Assert.NotNull(connectionId); + Assert.NotNull(requestId); + + var events = _listener.EventData.Where(e => e != null && GetProperty(e, "connectionId") == connectionId).ToList(); + + { + var start = Assert.Single(events, e => e.EventName == "ConnectionStart"); + Assert.All(new[] { "connectionId", "remoteEndPoint", "localEndPoint" }, p => Assert.Contains(p, start.PayloadNames)); + Assert.Equal($"127.0.0.1:{port}", GetProperty(start, "localEndPoint")); + } + { + var stop = Assert.Single(events, e => e.EventName == "ConnectionStop"); + Assert.All(new[] { "connectionId" }, p => Assert.Contains(p, stop.PayloadNames)); + Assert.Same(KestrelEventSource.Log, stop.EventSource); + } + { + var requestStart = Assert.Single(events, e => e.EventName == "RequestStart"); + Assert.All(new[] { "connectionId", "requestId" }, p => Assert.Contains(p, requestStart.PayloadNames)); + Assert.Equal(requestId, GetProperty(requestStart, "requestId")); + Assert.Same(KestrelEventSource.Log, requestStart.EventSource); + } + { + var requestStop = Assert.Single(events, e => e.EventName == "RequestStop"); + Assert.All(new[] { "connectionId", "requestId" }, p => Assert.Contains(p, requestStop.PayloadNames)); + Assert.Equal(requestId, GetProperty(requestStop, "requestId")); + Assert.Same(KestrelEventSource.Log, requestStop.EventSource); + } + } + + private string GetProperty(EventWrittenEventArgs data, string propName) + => data.Payload[data.PayloadNames.IndexOf(propName)] as string; + + private class TestEventListener : EventListener + { + private volatile bool _disposed; + private ConcurrentQueue _events = new ConcurrentQueue(); + + public IEnumerable EventData => _events; + + protected override void OnEventWritten(EventWrittenEventArgs eventData) + { + if (!_disposed) + { + _events.Enqueue(eventData); + } + } + + public override void Dispose() + { + _disposed = true; + base.Dispose(); + } + } + + public void Dispose() + { + _listener.Dispose(); + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/GeneratedCodeTests.cs b/src/Servers/Kestrel/test/FunctionalTests/GeneratedCodeTests.cs new file mode 100644 index 0000000000..862890542f --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/GeneratedCodeTests.cs @@ -0,0 +1,55 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +#if NETCOREAPP2_1 +using System.IO; +using Microsoft.AspNetCore.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class GeneratedCodeTests + { + [Fact] + public void GeneratedCodeIsUpToDate() + { + var repositoryRoot = TestPathUtilities.GetSolutionRootDirectory("Microsoft.AspNetCore"); + + var httpHeadersGeneratedPath = Path.Combine(repositoryRoot, "src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs"); + var httpProtocolGeneratedPath = Path.Combine(repositoryRoot, "src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs"); + var httpUtilitiesGeneratedPath = Path.Combine(repositoryRoot, "src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpUtilities.Generated.cs"); + + var testHttpHeadersGeneratedPath = Path.GetTempFileName(); + var testHttpProtocolGeneratedPath = Path.GetTempFileName(); + var testHttpUtilitiesGeneratedPath = Path.GetTempFileName(); + + try + { + var currentHttpHeadersGenerated = File.ReadAllText(httpHeadersGeneratedPath); + var currentHttpProtocolGenerated = File.ReadAllText(httpProtocolGeneratedPath); + var currentHttpUtilitiesGenerated = File.ReadAllText(httpUtilitiesGeneratedPath); + + CodeGenerator.Program.Run(testHttpHeadersGeneratedPath, testHttpProtocolGeneratedPath, testHttpUtilitiesGeneratedPath); + + var testHttpHeadersGenerated = File.ReadAllText(testHttpHeadersGeneratedPath); + var testHttpProtocolGenerated = File.ReadAllText(testHttpProtocolGeneratedPath); + var testHttpUtilitiesGenerated = File.ReadAllText(testHttpUtilitiesGeneratedPath); + + Assert.Equal(currentHttpHeadersGenerated, testHttpHeadersGenerated, ignoreLineEndingDifferences: true); + Assert.Equal(currentHttpProtocolGenerated, testHttpProtocolGenerated, ignoreLineEndingDifferences: true); + Assert.Equal(currentHttpUtilitiesGenerated, testHttpUtilitiesGenerated, ignoreLineEndingDifferences: true); + + } + finally + { + File.Delete(testHttpHeadersGeneratedPath); + File.Delete(testHttpProtocolGeneratedPath); + File.Delete(testHttpUtilitiesGeneratedPath); + } + } + } +} +#elif NET461 +#else +#error Target framework needs to be updated +#endif diff --git a/src/Servers/Kestrel/test/FunctionalTests/HttpConnectionManagerTests.cs b/src/Servers/Kestrel/test/FunctionalTests/HttpConnectionManagerTests.cs new file mode 100644 index 0000000000..ef8ac428b1 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/HttpConnectionManagerTests.cs @@ -0,0 +1,78 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Microsoft.AspNetCore.Testing.xunit; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class HttpConnectionManagerTests : LoggedTest + { +// This test causes MemoryPoolBlocks to be finalized which in turn causes an assert failure in debug builds. +#if !DEBUG + [ConditionalFact] + [NoDebuggerCondition] + public async Task CriticalErrorLoggedIfApplicationDoesntComplete() + { + //////////////////////////////////////////////////////////////////////////////////////// + // WARNING: This test will fail under a debugger because Task.s_currentActiveTasks // + // roots HttpConnection. // + //////////////////////////////////////////////////////////////////////////////////////// + + var logWh = new SemaphoreSlim(0); + var appStartedWh = new SemaphoreSlim(0); + + var mockTrace = new Mock(); + mockTrace + .Setup(trace => trace.ApplicationNeverCompleted(It.IsAny())) + .Callback(() => + { + logWh.Release(); + }); + + using (var server = new TestServer(context => + { + appStartedWh.Release(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + return tcs.Task; + }, + new TestServiceContext(new LoggerFactory(), mockTrace.Object))) + { + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGet(); + + Assert.True(await appStartedWh.WaitAsync(TestConstants.DefaultTimeout)); + + // Close connection without waiting for a response + } + + var logWaitAttempts = 0; + + for (; !await logWh.WaitAsync(TimeSpan.FromSeconds(1)) && logWaitAttempts < 30; logWaitAttempts++) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + } + + Assert.True(logWaitAttempts < 10); + } + } +#endif + + private class NoDebuggerConditionAttribute : Attribute, ITestCondition + { + public bool IsMet => !Debugger.IsAttached; + public string SkipReason => "A debugger is attached."; + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/HttpProtocolSelectionTests.cs b/src/Servers/Kestrel/test/FunctionalTests/HttpProtocolSelectionTests.cs new file mode 100644 index 0000000000..e7bdaadcbf --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/HttpProtocolSelectionTests.cs @@ -0,0 +1,97 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class HttpProtocolSelectionTests : TestApplicationErrorLoggerLoggedTest + { + [Fact] + public Task Server_NoProtocols_Error() + { + return TestError(HttpProtocols.None, CoreStrings.EndPointRequiresAtLeastOneProtocol); + } + + [Fact] + public Task Server_Http1AndHttp2_Cleartext_Error() + { + return TestError(HttpProtocols.Http1AndHttp2, CoreStrings.EndPointRequiresTlsForHttp1AndHttp2); + } + + [Fact] + public Task Server_Http1Only_Cleartext_Success() + { + return TestSuccess(HttpProtocols.Http1, "GET / HTTP/1.1\r\nHost:\r\n\r\n", "HTTP/1.1 200 OK"); + } + + [Fact] + public Task Server_Http2Only_Cleartext_Success() + { + // Expect a SETTINGS frame (type 0x4) with no payload and no flags + return TestSuccess(HttpProtocols.Http2, Encoding.ASCII.GetString(Http2Connection.ClientPreface), "\x00\x00\x00\x04\x00\x00\x00\x00\x00"); + } + + private async Task TestSuccess(HttpProtocols serverProtocols, string request, string expectedResponse) + { + var builder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel(options => + { + options.Listen(IPAddress.Loopback, 0, listenOptions => + { + listenOptions.Protocols = serverProtocols; + }); + }) + .Configure(app => app.Run(context => Task.CompletedTask)); + + using (var host = builder.Build()) + { + host.Start(); + + using (var connection = new TestConnection(host.GetPort())) + { + await connection.Send(request); + await connection.Receive(expectedResponse); + } + } + } + + private async Task TestError(HttpProtocols serverProtocols, string expectedErrorMessage) + where TException : Exception + { + var builder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel(options => options.Listen(IPAddress.Loopback, 0, listenOptions => + { + listenOptions.Protocols = serverProtocols; + })) + .Configure(app => app.Run(context => Task.CompletedTask)); + + using (var host = builder.Build()) + { + host.Start(); + + using (var connection = new TestConnection(host.GetPort())) + { + await connection.WaitForConnectionClose().DefaultTimeout(); + } + } + + Assert.Single(TestApplicationErrorLogger.Messages, message => message.LogLevel == LogLevel.Error + && message.EventId.Id == 0 + && message.Message == expectedErrorMessage); + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/HttpsConnectionAdapterOptionsTest.cs b/src/Servers/Kestrel/test/FunctionalTests/HttpsConnectionAdapterOptionsTest.cs new file mode 100644 index 0000000000..7d32d4be96 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/HttpsConnectionAdapterOptionsTest.cs @@ -0,0 +1,56 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class HttpsConnectionAdapterOptionsTests + { + [Fact] + public void HandshakeTimeoutDefault() + { + Assert.Equal(TimeSpan.FromSeconds(10), new HttpsConnectionAdapterOptions().HandshakeTimeout); + } + + [Theory] + [MemberData(nameof(TimeoutValidData))] + public void HandshakeTimeoutValid(TimeSpan value) + { + Assert.Equal(value, new HttpsConnectionAdapterOptions { HandshakeTimeout = value }.HandshakeTimeout); + } + + [Fact] + public void HandshakeTimeoutCanBeSetToInfinite() + { + Assert.Equal(TimeSpan.MaxValue, new HttpsConnectionAdapterOptions { HandshakeTimeout = Timeout.InfiniteTimeSpan }.HandshakeTimeout); + } + + [Theory] + [MemberData(nameof(TimeoutInvalidData))] + public void HandshakeTimeoutInvalid(TimeSpan value) + { + var exception = Assert.Throws(() => new HttpsConnectionAdapterOptions { HandshakeTimeout = value }); + + Assert.Equal("value", exception.ParamName); + Assert.StartsWith(CoreStrings.PositiveTimeSpanRequired, exception.Message); + } + + public static TheoryData TimeoutValidData => new TheoryData + { + TimeSpan.FromTicks(1), + TimeSpan.MaxValue, + }; + + public static TheoryData TimeoutInvalidData => new TheoryData + { + TimeSpan.MinValue, + TimeSpan.FromTicks(-1), + TimeSpan.Zero + }; + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/HttpsConnectionAdapterTests.cs b/src/Servers/Kestrel/test/FunctionalTests/HttpsConnectionAdapterTests.cs new file mode 100644 index 0000000000..1531e07b97 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/HttpsConnectionAdapterTests.cs @@ -0,0 +1,662 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class HttpsConnectionAdapterTests : LoggedTest + { + private static X509Certificate2 _x509Certificate2 = TestResources.GetTestCertificate(); + private static X509Certificate2 _x509Certificate2NoExt = TestResources.GetTestCertificate("no_extensions.pfx"); + + // https://github.com/aspnet/KestrelHttpServer/issues/240 + // This test currently fails on mono because of an issue with SslStream. + [Fact] + public async Task CanReadAndWriteWithHttpsConnectionAdapter() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions { ServerCertificate = _x509Certificate2 }) + } + }; + + using (var server = new TestServer(App, new TestServiceContext(LoggerFactory), listenOptions)) + { + var result = await HttpClientSlim.PostAsync($"https://localhost:{server.Port}/", + new FormUrlEncodedContent(new[] { + new KeyValuePair("content", "Hello World?") + }), + validateCertificate: false); + + Assert.Equal("content=Hello+World%3F", result); + } + } + + [Fact] + public async Task RequireCertificateFailsWhenNoCertificate() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificate = _x509Certificate2, + ClientCertificateMode = ClientCertificateMode.RequireCertificate + }) + } + }; + + + using (var server = new TestServer(App, new TestServiceContext(LoggerFactory), listenOptions)) + { + await Assert.ThrowsAnyAsync( + () => HttpClientSlim.GetStringAsync($"https://localhost:{server.Port}/")); + } + } + + [Fact] + public async Task AllowCertificateContinuesWhenNoCertificate() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificate = _x509Certificate2, + ClientCertificateMode = ClientCertificateMode.AllowCertificate + }) + } + }; + + using (var server = new TestServer(context => + { + var tlsFeature = context.Features.Get(); + Assert.NotNull(tlsFeature); + Assert.Null(tlsFeature.ClientCertificate); + return context.Response.WriteAsync("hello world"); + }, new TestServiceContext(LoggerFactory), listenOptions)) + { + var result = await HttpClientSlim.GetStringAsync($"https://localhost:{server.Port}/", validateCertificate: false); + Assert.Equal("hello world", result); + } + } + + [Fact] + public void ThrowsWhenNoServerCertificateIsProvided() + { + Assert.Throws(() => new HttpsConnectionAdapter( + new HttpsConnectionAdapterOptions()) + ); + } + + [Fact] + public async Task UsesProvidedServerCertificate() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions { ServerCertificate = _x509Certificate2 }) + } + }; + + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory), listenOptions)) + { + using (var client = new TcpClient()) + { + // SslStream is used to ensure the certificate is actually passed to the server + // HttpClient might not send the certificate because it is invalid or it doesn't match any + // of the certificate authorities sent by the server in the SSL handshake. + var stream = await OpenSslStream(client, server); + await stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false); + Assert.True(stream.RemoteCertificate.Equals(_x509Certificate2)); + } + } + } + + [Fact] + public async Task UsesProvidedServerCertificateSelector() + { + var selectorCalled = 0; + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificateSelector = (connection, name) => + { + Assert.NotNull(connection); + Assert.NotNull(connection.Features.Get()); +#if NETCOREAPP2_1 + Assert.Equal("localhost", name); +#else + Assert.Null(name); +#endif + selectorCalled++; + return _x509Certificate2; + } + }) + } + }; + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory), listenOptions)) + { + using (var client = new TcpClient()) + { + // SslStream is used to ensure the certificate is actually passed to the server + // HttpClient might not send the certificate because it is invalid or it doesn't match any + // of the certificate authorities sent by the server in the SSL handshake. + var stream = await OpenSslStream(client, server); + await stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false); + Assert.True(stream.RemoteCertificate.Equals(_x509Certificate2)); + Assert.Equal(1, selectorCalled); + } + } + } + + [Fact] + public async Task UsesProvidedServerCertificateSelectorEachTime() + { + var selectorCalled = 0; + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificateSelector = (connection, name) => + { + Assert.NotNull(connection); + Assert.NotNull(connection.Features.Get()); +#if NETCOREAPP2_1 + Assert.Equal("localhost", name); +#else + Assert.Null(name); +#endif + selectorCalled++; + if (selectorCalled == 1) + { + return _x509Certificate2; + } + return _x509Certificate2NoExt; + } + }) + } + }; + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory), listenOptions)) + { + using (var client = new TcpClient()) + { + // SslStream is used to ensure the certificate is actually passed to the server + // HttpClient might not send the certificate because it is invalid or it doesn't match any + // of the certificate authorities sent by the server in the SSL handshake. + var stream = await OpenSslStream(client, server); + await stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false); + Assert.True(stream.RemoteCertificate.Equals(_x509Certificate2)); + Assert.Equal(1, selectorCalled); + } + using (var client = new TcpClient()) + { + // SslStream is used to ensure the certificate is actually passed to the server + // HttpClient might not send the certificate because it is invalid or it doesn't match any + // of the certificate authorities sent by the server in the SSL handshake. + var stream = await OpenSslStream(client, server); + await stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false); + Assert.True(stream.RemoteCertificate.Equals(_x509Certificate2NoExt)); + Assert.Equal(2, selectorCalled); + } + } + } + + [Fact] + public async Task UsesProvidedServerCertificateSelectorValidatesEkus() + { + var selectorCalled = 0; + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificateSelector = (features, name) => + { + selectorCalled++; + return TestResources.GetTestCertificate("eku.code_signing.pfx"); + } + }) + } + }; + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory), listenOptions)) + { + using (var client = new TcpClient()) + { + // SslStream is used to ensure the certificate is actually passed to the server + // HttpClient might not send the certificate because it is invalid or it doesn't match any + // of the certificate authorities sent by the server in the SSL handshake. + var stream = await OpenSslStream(client, server); + await Assert.ThrowsAsync(() => + stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false)); + Assert.Equal(1, selectorCalled); + } + } + } + + [Fact] + public async Task UsesProvidedServerCertificateSelectorOverridesServerCertificate() + { + var selectorCalled = 0; + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificate = _x509Certificate2NoExt, + ServerCertificateSelector = (connection, name) => + { + Assert.NotNull(connection); + Assert.NotNull(connection.Features.Get()); +#if NETCOREAPP2_1 + Assert.Equal("localhost", name); +#else + Assert.Null(name); +#endif + selectorCalled++; + return _x509Certificate2; + } + }) + } + }; + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory), listenOptions)) + { + using (var client = new TcpClient()) + { + // SslStream is used to ensure the certificate is actually passed to the server + // HttpClient might not send the certificate because it is invalid or it doesn't match any + // of the certificate authorities sent by the server in the SSL handshake. + var stream = await OpenSslStream(client, server); + await stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false); + Assert.True(stream.RemoteCertificate.Equals(_x509Certificate2)); + Assert.Equal(1, selectorCalled); + } + } + } + + [Fact] + public async Task UsesProvidedServerCertificateSelectorFailsIfYouReturnNull() + { + var selectorCalled = 0; + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificateSelector = (features, name) => + { + selectorCalled++; + return null; + } + }) + } + }; + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory), listenOptions)) + { + using (var client = new TcpClient()) + { + // SslStream is used to ensure the certificate is actually passed to the server + // HttpClient might not send the certificate because it is invalid or it doesn't match any + // of the certificate authorities sent by the server in the SSL handshake. + var stream = await OpenSslStream(client, server); + await Assert.ThrowsAsync(() => + stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false)); + Assert.Equal(1, selectorCalled); + } + } + } + + [Fact] + public async Task CertificatePassedToHttpContext() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificate = _x509Certificate2, + ClientCertificateMode = ClientCertificateMode.RequireCertificate, + ClientCertificateValidation = (certificate, chain, sslPolicyErrors) => true + }) + } + }; + + using (var server = new TestServer(context => + { + var tlsFeature = context.Features.Get(); + Assert.NotNull(tlsFeature); + Assert.NotNull(tlsFeature.ClientCertificate); + Assert.NotNull(context.Connection.ClientCertificate); + return context.Response.WriteAsync("hello world"); + }, new TestServiceContext(LoggerFactory), listenOptions)) + { + using (var client = new TcpClient()) + { + // SslStream is used to ensure the certificate is actually passed to the server + // HttpClient might not send the certificate because it is invalid or it doesn't match any + // of the certificate authorities sent by the server in the SSL handshake. + var stream = await OpenSslStream(client, server); + await stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false); + await AssertConnectionResult(stream, true); + } + } + } + + [Fact] + public async Task HttpsSchemePassedToRequestFeature() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions { ServerCertificate = _x509Certificate2 }) + } + }; + + using (var server = new TestServer(context => context.Response.WriteAsync(context.Request.Scheme), new TestServiceContext(LoggerFactory), listenOptions)) + { + var result = await HttpClientSlim.GetStringAsync($"https://localhost:{server.Port}/", validateCertificate: false); + Assert.Equal("https", result); + } + } + + [Fact] + public async Task DoesNotSupportTls10() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificate = _x509Certificate2, + ClientCertificateMode = ClientCertificateMode.RequireCertificate, + ClientCertificateValidation = (certificate, chain, sslPolicyErrors) => true + }) + } + }; + + using (var server = new TestServer(context => context.Response.WriteAsync("hello world"), new TestServiceContext(LoggerFactory), listenOptions)) + { + // SslStream is used to ensure the certificate is actually passed to the server + // HttpClient might not send the certificate because it is invalid or it doesn't match any + // of the certificate authorities sent by the server in the SSL handshake. + using (var client = new TcpClient()) + { + var stream = await OpenSslStream(client, server); + var ex = await Assert.ThrowsAsync( + async () => await stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls, false)); + } + } + } + + [Theory] + [InlineData(ClientCertificateMode.AllowCertificate)] + [InlineData(ClientCertificateMode.RequireCertificate)] + public async Task ClientCertificateValidationGetsCalledWithNotNullParameters(ClientCertificateMode mode) + { + var clientCertificateValidationCalled = false; + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificate = _x509Certificate2, + ClientCertificateMode = mode, + ClientCertificateValidation = (certificate, chain, sslPolicyErrors) => + { + clientCertificateValidationCalled = true; + Assert.NotNull(certificate); + Assert.NotNull(chain); + return true; + } + }) + } + }; + + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory), listenOptions)) + { + using (var client = new TcpClient()) + { + var stream = await OpenSslStream(client, server); + await stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false); + await AssertConnectionResult(stream, true); + Assert.True(clientCertificateValidationCalled); + } + } + } + + [Theory] + [InlineData(ClientCertificateMode.AllowCertificate)] + [InlineData(ClientCertificateMode.RequireCertificate)] + public async Task ValidationFailureRejectsConnection(ClientCertificateMode mode) + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificate = _x509Certificate2, + ClientCertificateMode = mode, + ClientCertificateValidation = (certificate, chain, sslPolicyErrors) => false + }) + } + }; + + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory), listenOptions)) + { + using (var client = new TcpClient()) + { + var stream = await OpenSslStream(client, server); + await stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false); + await AssertConnectionResult(stream, false); + } + } + } + + [Theory] + [InlineData(ClientCertificateMode.AllowCertificate)] + [InlineData(ClientCertificateMode.RequireCertificate)] + public async Task RejectsConnectionOnSslPolicyErrorsWhenNoValidation(ClientCertificateMode mode) + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificate = _x509Certificate2, + ClientCertificateMode = mode + }) + } + }; + + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory), listenOptions)) + { + using (var client = new TcpClient()) + { + var stream = await OpenSslStream(client, server); + await stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false); + await AssertConnectionResult(stream, false); + } + } + } + + [Fact] + public async Task CertificatePassedToHttpContextIsNotDisposed() + { + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificate = _x509Certificate2, + ClientCertificateMode = ClientCertificateMode.RequireCertificate, + ClientCertificateValidation = (certificate, chain, sslPolicyErrors) => true + }) + } + }; + + RequestDelegate app = context => + { + var tlsFeature = context.Features.Get(); + Assert.NotNull(tlsFeature); + Assert.NotNull(tlsFeature.ClientCertificate); + Assert.NotNull(context.Connection.ClientCertificate); + Assert.NotNull(context.Connection.ClientCertificate.PublicKey); + return context.Response.WriteAsync("hello world"); + }; + + using (var server = new TestServer(app, new TestServiceContext(LoggerFactory), listenOptions)) + { + // SslStream is used to ensure the certificate is actually passed to the server + // HttpClient might not send the certificate because it is invalid or it doesn't match any + // of the certificate authorities sent by the server in the SSL handshake. + using (var client = new TcpClient()) + { + var stream = await OpenSslStream(client, server); + await stream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false); + await AssertConnectionResult(stream, true); + } + } + } + + [Theory] + [InlineData("no_extensions.pfx")] + public void AcceptsCertificateWithoutExtensions(string testCertName) + { + var certPath = TestResources.GetCertPath(testCertName); + TestOutputHelper.WriteLine("Loading " + certPath); + var cert = new X509Certificate2(certPath, "testPassword"); + Assert.Empty(cert.Extensions.OfType()); + + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificate = cert, + }); + } + + [Theory] + [InlineData("eku.server.pfx")] + [InlineData("eku.multiple_usages.pfx")] + public void ValidatesEnhancedKeyUsageOnCertificate(string testCertName) + { + var certPath = TestResources.GetCertPath(testCertName); + TestOutputHelper.WriteLine("Loading " + certPath); + var cert = new X509Certificate2(certPath, "testPassword"); + Assert.NotEmpty(cert.Extensions); + var eku = Assert.Single(cert.Extensions.OfType()); + Assert.NotEmpty(eku.EnhancedKeyUsages); + + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificate = cert, + }); + } + + [Theory] + [InlineData("eku.code_signing.pfx")] + [InlineData("eku.client.pfx")] + public void ThrowsForCertificatesMissingServerEku(string testCertName) + { + var certPath = TestResources.GetCertPath(testCertName); + TestOutputHelper.WriteLine("Loading " + certPath); + var cert = new X509Certificate2(certPath, "testPassword"); + Assert.NotEmpty(cert.Extensions); + var eku = Assert.Single(cert.Extensions.OfType()); + Assert.NotEmpty(eku.EnhancedKeyUsages); + + var ex = Assert.Throws(() => + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions + { + ServerCertificate = cert, + })); + + Assert.Equal(CoreStrings.FormatInvalidServerCertificateEku(cert.Thumbprint), ex.Message); + } + + private static async Task App(HttpContext httpContext) + { + var request = httpContext.Request; + var response = httpContext.Response; + while (true) + { + var buffer = new byte[8192]; + var count = await request.Body.ReadAsync(buffer, 0, buffer.Length); + if (count == 0) + { + break; + } + await response.Body.WriteAsync(buffer, 0, count); + } + } + + private static async Task OpenSslStream(TcpClient client, TestServer server, X509Certificate2 clientCertificate = null) + { + await client.ConnectAsync("127.0.0.1", server.Port); + var stream = new SslStream(client.GetStream(), false, (sender, certificate, chain, errors) => true, + (sender, host, certificates, certificate, issuers) => clientCertificate ?? _x509Certificate2); + + return stream; + } + + private static async Task AssertConnectionResult(SslStream stream, bool success) + { + var request = Encoding.UTF8.GetBytes("GET / HTTP/1.0\r\n\r\n"); + await stream.WriteAsync(request, 0, request.Length); + var reader = new StreamReader(stream); + string line = null; + if (success) + { + line = await reader.ReadLineAsync(); + Assert.Equal("HTTP/1.1 200 OK", line); + } + else + { + try + { + line = await reader.ReadLineAsync(); + } + catch (IOException) { } + Assert.Null(line); + } + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/HttpsTests.cs b/src/Servers/Kestrel/test/FunctionalTests/HttpsTests.cs new file mode 100644 index 0000000000..9de8a29b48 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/HttpsTests.cs @@ -0,0 +1,523 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.AspNetCore.Testing.xunit; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions.Internal; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class HttpsTests : LoggedTest + { + private KestrelServerOptions CreateServerOptions() + { + var serverOptions = new KestrelServerOptions(); + serverOptions.ApplicationServices = new ServiceCollection() + .AddLogging() + .BuildServiceProvider(); + return serverOptions; + } + + [Fact] + public void UseHttpsDefaultsToDefaultCert() + { + var serverOptions = CreateServerOptions(); + var defaultCert = new X509Certificate2(TestResources.TestCertificatePath, "testPassword"); + serverOptions.DefaultCertificate = defaultCert; + + serverOptions.ListenLocalhost(5000, options => + { + options.UseHttps(); + }); + + Assert.False(serverOptions.IsDevCertLoaded); + + serverOptions.ListenLocalhost(5001, options => + { + options.UseHttps(opt => + { + // The default cert is applied after UseHttps. + Assert.Null(opt.ServerCertificate); + }); + }); + Assert.False(serverOptions.IsDevCertLoaded); + } + + [Fact] + public void ConfigureHttpsDefaultsNeverLoadsDefaultCert() + { + var serverOptions = CreateServerOptions(); + var testCert = new X509Certificate2(TestResources.TestCertificatePath, "testPassword"); + serverOptions.ConfigureHttpsDefaults(options => + { + Assert.Null(options.ServerCertificate); + options.ServerCertificate = testCert; + options.ClientCertificateMode = ClientCertificateMode.RequireCertificate; + }); + serverOptions.ListenLocalhost(5000, options => + { + options.UseHttps(opt => + { + Assert.Equal(testCert, opt.ServerCertificate); + Assert.Equal(ClientCertificateMode.RequireCertificate, opt.ClientCertificateMode); + }); + }); + // Never lazy loaded + Assert.False(serverOptions.IsDevCertLoaded); + Assert.Null(serverOptions.DefaultCertificate); + } + + [Fact] + public void ConfigureCertSelectorNeverLoadsDefaultCert() + { + var serverOptions = CreateServerOptions(); + var testCert = new X509Certificate2(TestResources.TestCertificatePath, "testPassword"); + serverOptions.ConfigureHttpsDefaults(options => + { + Assert.Null(options.ServerCertificate); + Assert.Null(options.ServerCertificateSelector); + options.ServerCertificateSelector = (features, name) => + { + return testCert; + }; + options.ClientCertificateMode = ClientCertificateMode.RequireCertificate; + }); + serverOptions.ListenLocalhost(5000, options => + { + options.UseHttps(opt => + { + Assert.Null(opt.ServerCertificate); + Assert.NotNull(opt.ServerCertificateSelector); + Assert.Equal(ClientCertificateMode.RequireCertificate, opt.ClientCertificateMode); + }); + }); + // Never lazy loaded + Assert.False(serverOptions.IsDevCertLoaded); + Assert.Null(serverOptions.DefaultCertificate); + } + + [Fact] + public async Task EmptyRequestLoggedAsDebug() + { + var loggerProvider = new HandshakeErrorLoggerProvider(); + LoggerFactory.AddProvider(loggerProvider); + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword"); + }); + }) + .ConfigureServices(AddTestLogging) + .ConfigureLogging(builder => builder.AddProvider(loggerProvider)) + .Configure(app => { }); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (await HttpClientSlim.GetSocket(new Uri($"http://127.0.0.1:{host.GetPort()}/"))) + { + // Close socket immediately + } + + await loggerProvider.FilterLogger.LogTcs.Task.DefaultTimeout(); + } + + Assert.Equal(1, loggerProvider.FilterLogger.LastEventId.Id); + Assert.Equal(LogLevel.Debug, loggerProvider.FilterLogger.LastLogLevel); + Assert.True(loggerProvider.ErrorLogger.TotalErrorsLogged == 0, + userMessage: string.Join(Environment.NewLine, loggerProvider.ErrorLogger.ErrorMessages)); + } + + [Fact] + public async Task ClientHandshakeFailureLoggedAsDebug() + { + var loggerProvider = new HandshakeErrorLoggerProvider(); + LoggerFactory.AddProvider(loggerProvider); + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword"); + }); + }) + .ConfigureServices(AddTestLogging) + .Configure(app => { }); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var socket = await HttpClientSlim.GetSocket(new Uri($"https://127.0.0.1:{host.GetPort()}/"))) + using (var stream = new NetworkStream(socket)) + { + // Send null bytes and close socket + await stream.WriteAsync(new byte[10], 0, 10); + } + + await loggerProvider.FilterLogger.LogTcs.Task.DefaultTimeout(); + } + + Assert.Equal(1, loggerProvider.FilterLogger.LastEventId.Id); + Assert.Equal(LogLevel.Debug, loggerProvider.FilterLogger.LastLogLevel); + Assert.True(loggerProvider.ErrorLogger.TotalErrorsLogged == 0, + userMessage: string.Join(Environment.NewLine, loggerProvider.ErrorLogger.ErrorMessages)); + } + + // Regression test for https://github.com/aspnet/KestrelHttpServer/issues/1103#issuecomment-246971172 + [Fact] + public async Task DoesNotThrowObjectDisposedExceptionOnConnectionAbort() + { + var loggerProvider = new HandshakeErrorLoggerProvider(); + LoggerFactory.AddProvider(loggerProvider); + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword"); + }); + }) + .ConfigureServices(AddTestLogging) + .ConfigureLogging(builder => builder.AddProvider(loggerProvider)) + .Configure(app => app.Run(async httpContext => + { + var ct = httpContext.RequestAborted; + while (!ct.IsCancellationRequested) + { + try + { + await httpContext.Response.WriteAsync($"hello, world", ct); + await Task.Delay(1000, ct); + } + catch (TaskCanceledException) + { + // Don't regard connection abort as an error + } + } + })); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var socket = await HttpClientSlim.GetSocket(new Uri($"https://127.0.0.1:{host.GetPort()}/"))) + using (var stream = new NetworkStream(socket, ownsSocket: false)) + using (var sslStream = new SslStream(stream, true, (sender, certificate, chain, errors) => true)) + { + await sslStream.AuthenticateAsClientAsync("127.0.0.1", clientCertificates: null, + enabledSslProtocols: SslProtocols.Tls11 | SslProtocols.Tls12, + checkCertificateRevocation: false); + + var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"); + await sslStream.WriteAsync(request, 0, request.Length); + + await sslStream.ReadAsync(new byte[32], 0, 32); + } + } + + Assert.False(loggerProvider.ErrorLogger.ObjectDisposedExceptionLogged); + } + + [Fact] + public async Task DoesNotThrowObjectDisposedExceptionFromWriteAsyncAfterConnectionIsAborted() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var loggerProvider = new HandshakeErrorLoggerProvider(); + LoggerFactory.AddProvider(loggerProvider); + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword"); + }); + }) + .ConfigureServices(AddTestLogging) + .ConfigureLogging(builder => builder.AddProvider(loggerProvider)) + .Configure(app => app.Run(async httpContext => + { + httpContext.Abort(); + try + { + await httpContext.Response.WriteAsync($"hello, world"); + tcs.SetResult(null); + } + catch (Exception ex) + { + tcs.SetException(ex); + } + })); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var socket = await HttpClientSlim.GetSocket(new Uri($"https://127.0.0.1:{host.GetPort()}/"))) + using (var stream = new NetworkStream(socket, ownsSocket: false)) + using (var sslStream = new SslStream(stream, true, (sender, certificate, chain, errors) => true)) + { + await sslStream.AuthenticateAsClientAsync("127.0.0.1", clientCertificates: null, + enabledSslProtocols: SslProtocols.Tls11 | SslProtocols.Tls12, + checkCertificateRevocation: false); + + var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"); + await sslStream.WriteAsync(request, 0, request.Length); + + await sslStream.ReadAsync(new byte[32], 0, 32); + } + } + + await tcs.Task.DefaultTimeout(); + } + + // Regression test for https://github.com/aspnet/KestrelHttpServer/issues/1693 + [Fact] + public async Task DoesNotThrowObjectDisposedExceptionOnEmptyConnection() + { + var loggerProvider = new HandshakeErrorLoggerProvider(); + LoggerFactory.AddProvider(loggerProvider); + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword"); + }); + }) + .ConfigureServices(AddTestLogging) + .ConfigureLogging(builder => builder.AddProvider(loggerProvider)) + .Configure(app => app.Run(httpContext => Task.CompletedTask)); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var socket = await HttpClientSlim.GetSocket(new Uri($"https://127.0.0.1:{host.GetPort()}/"))) + using (var stream = new NetworkStream(socket, ownsSocket: false)) + using (var sslStream = new SslStream(stream, true, (sender, certificate, chain, errors) => true)) + { + await sslStream.AuthenticateAsClientAsync("127.0.0.1", clientCertificates: null, + enabledSslProtocols: SslProtocols.Tls11 | SslProtocols.Tls12, + checkCertificateRevocation: false); + } + } + + Assert.False(loggerProvider.ErrorLogger.ObjectDisposedExceptionLogged); + } + + // Regression test for https://github.com/aspnet/KestrelHttpServer/pull/1197 + [ConditionalFact] + [OSSkipCondition(OperatingSystems.MacOSX, SkipReason = "macOS EPIPE vs. EPROTOTYPE bug https://github.com/aspnet/KestrelHttpServer/issues/2885")] + public void ConnectionFilterDoesNotLeakBlock() + { + var loggerProvider = new HandshakeErrorLoggerProvider(); + LoggerFactory.AddProvider(loggerProvider); + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword"); + }); + }) + .ConfigureServices(AddTestLogging) + .ConfigureLogging(builder => builder.AddProvider(loggerProvider)) + .Configure(app => { }); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort())); + + // Close socket immediately + socket.LingerState = new LingerOption(true, 0); + } + } + } + + [Fact] + public async Task HandshakeTimesOutAndIsLoggedAsDebug() + { + var loggerProvider = new HandshakeErrorLoggerProvider(); + LoggerFactory.AddProvider(loggerProvider); + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + listenOptions.UseHttps(o => + { + o.ServerCertificate = new X509Certificate2(TestResources.TestCertificatePath, "testPassword"); + o.HandshakeTimeout = TimeSpan.FromSeconds(1); + }); + }); + }) + .ConfigureServices(AddTestLogging) + .Configure(app => app.Run(httpContext => Task.CompletedTask)); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var socket = await HttpClientSlim.GetSocket(new Uri($"https://127.0.0.1:{host.GetPort()}/"))) + using (var stream = new NetworkStream(socket, ownsSocket: false)) + { + // No data should be sent and the connection should be closed in well under 30 seconds. + Assert.Equal(0, await stream.ReadAsync(new byte[1], 0, 1).DefaultTimeout()); + } + } + + await loggerProvider.FilterLogger.LogTcs.Task.DefaultTimeout(); + Assert.Equal(2, loggerProvider.FilterLogger.LastEventId); + Assert.Equal(LogLevel.Debug, loggerProvider.FilterLogger.LastLogLevel); + } + + [Fact] + public async Task ClientAttemptingToUseUnsupportedProtocolIsLoggedAsDebug() + { + var loggerProvider = new HandshakeErrorLoggerProvider(); + LoggerFactory.AddProvider(loggerProvider); + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword"); + }); + }) + .ConfigureServices(AddTestLogging) + .Configure(app => app.Run(httpContext => Task.CompletedTask)); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var socket = await HttpClientSlim.GetSocket(new Uri($"https://127.0.0.1:{host.GetPort()}/"))) + using (var stream = new NetworkStream(socket, ownsSocket: false)) + using (var sslStream = new SslStream(stream, true, (sender, certificate, chain, errors) => true)) + { + // SslProtocols.Tls is TLS 1.0 which isn't supported by Kestrel by default. + await Assert.ThrowsAsync(() => + sslStream.AuthenticateAsClientAsync("127.0.0.1", clientCertificates: null, + enabledSslProtocols: SslProtocols.Tls, + checkCertificateRevocation: false)); + } + } + + await loggerProvider.FilterLogger.LogTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout); + Assert.Equal(1, loggerProvider.FilterLogger.LastEventId); + Assert.Equal(LogLevel.Debug, loggerProvider.FilterLogger.LastLogLevel); + } + + private class HandshakeErrorLoggerProvider : ILoggerProvider + { + public HttpsConnectionFilterLogger FilterLogger { get; } = new HttpsConnectionFilterLogger(); + public ApplicationErrorLogger ErrorLogger { get; } = new ApplicationErrorLogger(); + + public ILogger CreateLogger(string categoryName) + { + if (categoryName == nameof(HttpsConnectionAdapter)) + { + return FilterLogger; + } + else + { + return ErrorLogger; + } + } + + public void Dispose() + { + } + } + + private class HttpsConnectionFilterLogger : ILogger + { + public LogLevel LastLogLevel { get; set; } + public EventId LastEventId { get; set; } + public TaskCompletionSource LogTcs { get; } = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) + { + LastLogLevel = logLevel; + LastEventId = eventId; + LogTcs.SetResult(null); + } + + public bool IsEnabled(LogLevel logLevel) + { + throw new NotImplementedException(); + } + + public IDisposable BeginScope(TState state) + { + throw new NotImplementedException(); + } + } + + private class ApplicationErrorLogger : ILogger + { + private List _errorMessages = new List(); + + public IEnumerable ErrorMessages => _errorMessages; + + public int TotalErrorsLogged => _errorMessages.Count; + + public bool ObjectDisposedExceptionLogged { get; set; } + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) + { + if (logLevel == LogLevel.Error) + { + var log = $"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception}"; + _errorMessages.Add(log); + } + + if (exception is ObjectDisposedException) + { + ObjectDisposedExceptionLogged = true; + } + } + + public bool IsEnabled(LogLevel logLevel) + { + return true; + } + + public IDisposable BeginScope(TState state) + { + return NullScope.Instance; + } + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/KeepAliveTimeoutTests.cs b/src/Servers/Kestrel/test/FunctionalTests/KeepAliveTimeoutTests.cs new file mode 100644 index 0000000000..90c2b77cfa --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/KeepAliveTimeoutTests.cs @@ -0,0 +1,256 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +#if !INNER_LOOP +using System; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class KeepAliveTimeoutTests : LoggedTest + { + private static readonly TimeSpan _keepAliveTimeout = TimeSpan.FromSeconds(10); + private static readonly TimeSpan _longDelay = TimeSpan.FromSeconds(30); + private static readonly TimeSpan _shortDelay = TimeSpan.FromSeconds(_longDelay.TotalSeconds / 10); + + [Fact] + public Task TestKeepAliveTimeout() + { + // Delays in these tests cannot be much longer than expected. + // Call Task.Run() to get rid of Xunit's synchronization context, + // otherwise it can cause unexpectedly longer delays when multiple tests + // are running in parallel. These tests becomes flaky on slower + // hardware because the continuations for the delay tasks might take too long to be + // scheduled if running on Xunit's synchronization context. + return Task.Run(async () => + { + var longRunningCancellationTokenSource = new CancellationTokenSource(); + var upgradeCancellationTokenSource = new CancellationTokenSource(); + + using (var server = CreateServer(longRunningCancellationTokenSource.Token, upgradeCancellationTokenSource.Token)) + { + var tasks = new[] + { + ConnectionClosedWhenKeepAliveTimeoutExpires(server), + ConnectionKeptAliveBetweenRequests(server), + ConnectionNotTimedOutWhileRequestBeingSent(server), + ConnectionNotTimedOutWhileAppIsRunning(server, longRunningCancellationTokenSource), + ConnectionTimesOutWhenOpenedButNoRequestSent(server), + KeepAliveTimeoutDoesNotApplyToUpgradedConnections(server, upgradeCancellationTokenSource) + }; + + await Task.WhenAll(tasks); + } + }); + } + + private async Task ConnectionClosedWhenKeepAliveTimeoutExpires(TestServer server) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await ReceiveResponse(connection); + await connection.WaitForConnectionClose().TimeoutAfter(_longDelay); + } + } + + private async Task ConnectionKeptAliveBetweenRequests(TestServer server) + { + using (var connection = server.CreateConnection()) + { + for (var i = 0; i < 10; i++) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + // Don't change this to Task.Delay. See https://github.com/aspnet/KestrelHttpServer/issues/1684#issuecomment-330285740. + Thread.Sleep(_shortDelay); + } + + for (var i = 0; i < 10; i++) + { + await ReceiveResponse(connection); + } + } + } + + private async Task ConnectionNotTimedOutWhileRequestBeingSent(TestServer server) + { + using (var connection = server.CreateConnection()) + { + var cts = new CancellationTokenSource(); + cts.CancelAfter(_longDelay); + + await connection.Send( + "POST /consume HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + ""); + + while (!cts.IsCancellationRequested) + { + await connection.Send( + "1", + "a", + ""); + await Task.Delay(_shortDelay); + } + + await connection.Send( + "0", + "", + ""); + await ReceiveResponse(connection); + } + } + + private async Task ConnectionNotTimedOutWhileAppIsRunning(TestServer server, CancellationTokenSource cts) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET /longrunning HTTP/1.1", + "Host:", + "", + ""); + cts.CancelAfter(_longDelay); + + while (!cts.IsCancellationRequested) + { + await Task.Delay(1000); + } + + await ReceiveResponse(connection); + + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await ReceiveResponse(connection); + } + } + + private async Task ConnectionTimesOutWhenOpenedButNoRequestSent(TestServer server) + { + using (var connection = server.CreateConnection()) + { + await Task.Delay(_longDelay); + await connection.WaitForConnectionClose().TimeoutAfter(_longDelay); + } + } + + private async Task KeepAliveTimeoutDoesNotApplyToUpgradedConnections(TestServer server, CancellationTokenSource cts) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET /upgrade HTTP/1.1", + "Host:", + "Connection: Upgrade", + "", + ""); + await connection.Receive( + "HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + ""); + await connection.ReceiveStartsWith("Date: "); + await connection.Receive( + "", + ""); + cts.CancelAfter(_longDelay); + + while (!cts.IsCancellationRequested) + { + await Task.Delay(1000); + } + + await connection.Receive("hello, world"); + } + } + + private TestServer CreateServer(CancellationToken longRunningCt, CancellationToken upgradeCt) + { + return new TestServer(httpContext => App(httpContext, longRunningCt, upgradeCt), new TestServiceContext(LoggerFactory) + { + // Use real SystemClock so timeouts trigger. + SystemClock = new SystemClock(), + ServerOptions = + { + AddServerHeader = false, + Limits = + { + KeepAliveTimeout = _keepAliveTimeout, + MinRequestBodyDataRate = null + } + } + }); + } + + private async Task App(HttpContext httpContext, CancellationToken longRunningCt, CancellationToken upgradeCt) + { + var ct = httpContext.RequestAborted; + var responseStream = httpContext.Response.Body; + var responseBytes = Encoding.ASCII.GetBytes("hello, world"); + + if (httpContext.Request.Path == "/longrunning") + { + while (!longRunningCt.IsCancellationRequested) + { + await Task.Delay(1000); + } + } + else if (httpContext.Request.Path == "/upgrade") + { + using (var stream = await httpContext.Features.Get().UpgradeAsync()) + { + while (!upgradeCt.IsCancellationRequested) + { + await Task.Delay(_longDelay); + } + + responseStream = stream; + } + } + else if (httpContext.Request.Path == "/consume") + { + var buffer = new byte[1024]; + while (await httpContext.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) ; + } + + await responseStream.WriteAsync(responseBytes, 0, responseBytes.Length); + } + + private async Task ReceiveResponse(TestConnection connection) + { + await connection.Receive( + "HTTP/1.1 200 OK", + ""); + await connection.ReceiveStartsWith("Date: "); + await connection.Receive( + "Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + } +} +#endif \ No newline at end of file diff --git a/src/Servers/Kestrel/test/FunctionalTests/LoggingConnectionAdapterTests.cs b/src/Servers/Kestrel/test/FunctionalTests/LoggingConnectionAdapterTests.cs new file mode 100644 index 0000000000..728d9a16fb --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/LoggingConnectionAdapterTests.cs @@ -0,0 +1,53 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class LoggingConnectionAdapterTests : LoggedTest + { + [Fact] + public async Task LoggingConnectionAdapterCanBeAddedBeforeAndAfterHttpsAdapter() + { + var host = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + listenOptions.UseConnectionLogging(); + listenOptions.UseHttps(TestResources.TestCertificatePath, "testPassword"); + listenOptions.UseConnectionLogging(); + }); + }) + .Configure(app => + { + app.Run(context => + { + context.Response.ContentLength = 12; + return context.Response.WriteAsync("Hello World!"); + }); + }) + .Build(); + + using (host) + { + await host.StartAsync(); + + var response = await HttpClientSlim.GetStringAsync($"https://localhost:{host.GetPort()}/", validateCertificate: false) + .DefaultTimeout(); + + Assert.Equal("Hello World!", response); + } + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/MaxRequestBodySizeTests.cs b/src/Servers/Kestrel/test/FunctionalTests/MaxRequestBodySizeTests.cs new file mode 100644 index 0000000000..4bc664f558 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/MaxRequestBodySizeTests.cs @@ -0,0 +1,496 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class MaxRequestBodySizeTests : LoggedTest + { + [Fact] + public async Task RejectsRequestWithContentLengthHeaderExceedingGlobalLimit() + { + // 4 GiB + var globalMaxRequestBodySize = 0x100000000; + BadHttpRequestException requestRejectedEx = null; + + using (var server = new TestServer(async context => + { + var buffer = new byte[1]; + requestRejectedEx = await Assert.ThrowsAsync( + async () => await context.Request.Body.ReadAsync(buffer, 0, 1)); + throw requestRejectedEx; + }, + new TestServiceContext(LoggerFactory) { ServerOptions = { Limits = { MaxRequestBodySize = globalMaxRequestBodySize } } })) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: " + (globalMaxRequestBodySize + 1), + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 413 Payload Too Large", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.NotNull(requestRejectedEx); + Assert.Equal(CoreStrings.BadRequest_RequestBodyTooLarge, requestRejectedEx.Message); + } + + [Fact] + public async Task RejectsRequestWithContentLengthHeaderExceedingPerRequestLimit() + { + // 8 GiB + var globalMaxRequestBodySize = 0x200000000; + // 4 GiB + var perRequestMaxRequestBodySize = 0x100000000; + BadHttpRequestException requestRejectedEx = null; + + using (var server = new TestServer(async context => + { + var feature = context.Features.Get(); + Assert.Equal(globalMaxRequestBodySize, feature.MaxRequestBodySize); + + // Disable the MaxRequestBodySize prior to calling Request.Body.ReadAsync(); + feature.MaxRequestBodySize = perRequestMaxRequestBodySize; + + var buffer = new byte[1]; + requestRejectedEx = await Assert.ThrowsAsync( + async () => await context.Request.Body.ReadAsync(buffer, 0, 1)); + throw requestRejectedEx; + }, + new TestServiceContext(LoggerFactory) { ServerOptions = { Limits = { MaxRequestBodySize = globalMaxRequestBodySize } } })) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: " + (perRequestMaxRequestBodySize + 1), + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 413 Payload Too Large", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.NotNull(requestRejectedEx); + Assert.Equal(CoreStrings.BadRequest_RequestBodyTooLarge, requestRejectedEx.Message); + } + + [Fact] + public async Task DoesNotRejectRequestWithContentLengthHeaderExceedingGlobalLimitIfLimitDisabledPerRequest() + { + using (var server = new TestServer(async context => + { + var feature = context.Features.Get(); + Assert.Equal(0, feature.MaxRequestBodySize); + + // Disable the MaxRequestBodySize prior to calling Request.Body.ReadAsync(); + feature.MaxRequestBodySize = null; + + var buffer = new byte[1]; + + Assert.Equal(1, await context.Request.Body.ReadAsync(buffer, 0, 1)); + Assert.Equal(buffer[0], (byte)'A'); + Assert.Equal(0, await context.Request.Body.ReadAsync(buffer, 0, 1)); + + context.Response.ContentLength = 1; + await context.Response.Body.WriteAsync(buffer, 0, 1); + }, + new TestServiceContext(LoggerFactory) { ServerOptions = { Limits = { MaxRequestBodySize = 0 } } })) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 1", + "", + "A"); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 1", + "", + "A"); + } + } + } + + [Fact] + public async Task DoesNotRejectBodylessGetRequestWithZeroMaxRequestBodySize() + { + using (var server = new TestServer(context => context.Request.Body.CopyToAsync(Stream.Null), + new TestServiceContext { ServerOptions = { Limits = { MaxRequestBodySize = 0 } } })) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + "POST / HTTP/1.1", + "Host:", + "Content-Length: 1", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + "HTTP/1.1 413 Payload Too Large", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Fact] + public async Task SettingMaxRequestBodySizeAfterReadingFromRequestBodyThrows() + { + var perRequestMaxRequestBodySize = 0x10; + var payloadSize = perRequestMaxRequestBodySize + 1; + var payload = new string('A', payloadSize); + InvalidOperationException invalidOpEx = null; + + using (var server = new TestServer(async context => + { + var buffer = new byte[1]; + Assert.Equal(1, await context.Request.Body.ReadAsync(buffer, 0, 1)); + + var feature = context.Features.Get(); + Assert.Equal(new KestrelServerLimits().MaxRequestBodySize, feature.MaxRequestBodySize); + Assert.True(feature.IsReadOnly); + + invalidOpEx = Assert.Throws(() => + feature.MaxRequestBodySize = perRequestMaxRequestBodySize); + throw invalidOpEx; + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: " + payloadSize, + "", + payload); + await connection.Receive( + "HTTP/1.1 500 Internal Server Error", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.NotNull(invalidOpEx); + Assert.Equal(CoreStrings.MaxRequestBodySizeCannotBeModifiedAfterRead, invalidOpEx.Message); + } + + [Fact] + public async Task SettingMaxRequestBodySizeAfterUpgradingRequestThrows() + { + InvalidOperationException invalidOpEx = null; + + using (var server = new TestServer(async context => + { + var upgradeFeature = context.Features.Get(); + var stream = await upgradeFeature.UpgradeAsync(); + + var feature = context.Features.Get(); + Assert.Equal(new KestrelServerLimits().MaxRequestBodySize, feature.MaxRequestBodySize); + Assert.True(feature.IsReadOnly); + + invalidOpEx = Assert.Throws(() => + feature.MaxRequestBodySize = 0x10); + throw invalidOpEx; + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send("GET / HTTP/1.1", + "Host:", + "Connection: Upgrade", + "", + ""); + await connection.Receive("HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {server.Context.DateHeaderValue}", + "", + ""); + await connection.ReceiveForcedEnd(); + } + } + + Assert.NotNull(invalidOpEx); + Assert.Equal(CoreStrings.MaxRequestBodySizeCannotBeModifiedForUpgradedRequests, invalidOpEx.Message); + } + + [Fact] + public async Task EveryReadFailsWhenContentLengthHeaderExceedsGlobalLimit() + { + BadHttpRequestException requestRejectedEx1 = null; + BadHttpRequestException requestRejectedEx2 = null; + + using (var server = new TestServer(async context => + { + var buffer = new byte[1]; + requestRejectedEx1 = await Assert.ThrowsAsync( + async () => await context.Request.Body.ReadAsync(buffer, 0, 1)); + requestRejectedEx2 = await Assert.ThrowsAsync( + async () => await context.Request.Body.ReadAsync(buffer, 0, 1)); + throw requestRejectedEx2; + }, + new TestServiceContext(LoggerFactory) { ServerOptions = { Limits = { MaxRequestBodySize = 0 } } })) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: " + (new KestrelServerLimits().MaxRequestBodySize + 1), + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 413 Payload Too Large", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.NotNull(requestRejectedEx1); + Assert.NotNull(requestRejectedEx2); + Assert.Equal(CoreStrings.BadRequest_RequestBodyTooLarge, requestRejectedEx1.Message); + Assert.Equal(CoreStrings.BadRequest_RequestBodyTooLarge, requestRejectedEx2.Message); + } + + [Fact] + public async Task ChunkFramingAndExtensionsCountTowardsRequestBodySize() + { + var chunkedPayload = "5;random chunk extension\r\nHello\r\n6\r\n World\r\n0\r\n\r\n"; + var globalMaxRequestBodySize = chunkedPayload.Length - 1; + BadHttpRequestException requestRejectedEx = null; + + using (var server = new TestServer(async context => + { + var buffer = new byte[11]; + requestRejectedEx = await Assert.ThrowsAsync(async () => + { + var count = 0; + do + { + count = await context.Request.Body.ReadAsync(buffer, 0, 11); + } while (count != 0); + }); + + throw requestRejectedEx; + }, + new TestServiceContext(LoggerFactory) { ServerOptions = { Limits = { MaxRequestBodySize = globalMaxRequestBodySize } } })) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + chunkedPayload); + await connection.ReceiveForcedEnd( + "HTTP/1.1 413 Payload Too Large", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.NotNull(requestRejectedEx); + Assert.Equal(CoreStrings.BadRequest_RequestBodyTooLarge, requestRejectedEx.Message); + } + + [Fact] + public async Task TrailingHeadersDoNotCountTowardsRequestBodySize() + { + var chunkedPayload = $"5;random chunk extension\r\nHello\r\n6\r\n World\r\n0\r\n"; + var trailingHeaders = "Trailing-Header: trailing-value\r\n\r\n"; + var globalMaxRequestBodySize = chunkedPayload.Length; + + using (var server = new TestServer(async context => + { + var offset = 0; + var count = 0; + var buffer = new byte[11]; + + do + { + count = await context.Request.Body.ReadAsync(buffer, offset, 11 - offset); + offset += count; + } while (count != 0); + + Assert.Equal("Hello World", Encoding.ASCII.GetString(buffer)); + Assert.Equal("trailing-value", context.Request.Headers["Trailing-Header"].ToString()); + }, + new TestServiceContext(LoggerFactory) { ServerOptions = { Limits = { MaxRequestBodySize = globalMaxRequestBodySize } } })) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + chunkedPayload + trailingHeaders); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Fact] + public async Task PerRequestMaxRequestBodySizeGetsReset() + { + var chunkedPayload = "5;random chunk extension\r\nHello\r\n6\r\n World\r\n0\r\n\r\n"; + var globalMaxRequestBodySize = chunkedPayload.Length - 1; + var firstRequest = true; + BadHttpRequestException requestRejectedEx = null; + + using (var server = new TestServer(async context => + { + var feature = context.Features.Get(); + Assert.Equal(globalMaxRequestBodySize, feature.MaxRequestBodySize); + + var buffer = new byte[11]; + var count = 0; + + if (firstRequest) + { + firstRequest = false; + feature.MaxRequestBodySize = chunkedPayload.Length; + + do + { + count = await context.Request.Body.ReadAsync(buffer, 0, 11); + } while (count != 0); + } + else + { + requestRejectedEx = await Assert.ThrowsAsync(async () => + { + do + { + count = await context.Request.Body.ReadAsync(buffer, 0, 11); + } while (count != 0); + }); + + throw requestRejectedEx; + } + }, + new TestServiceContext(LoggerFactory) { ServerOptions = { Limits = { MaxRequestBodySize = globalMaxRequestBodySize } } })) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + chunkedPayload + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + chunkedPayload); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + "HTTP/1.1 413 Payload Too Large", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.NotNull(requestRejectedEx); + Assert.Equal(CoreStrings.BadRequest_RequestBodyTooLarge, requestRejectedEx.Message); + } + + [Fact] + public async Task EveryReadFailsWhenChunkedPayloadExceedsGlobalLimit() + { + BadHttpRequestException requestRejectedEx1 = null; + BadHttpRequestException requestRejectedEx2 = null; + + using (var server = new TestServer(async context => + { + var buffer = new byte[1]; + requestRejectedEx1 = await Assert.ThrowsAsync( + async () => await context.Request.Body.ReadAsync(buffer, 0, 1)); + requestRejectedEx2 = await Assert.ThrowsAsync( + async () => await context.Request.Body.ReadAsync(buffer, 0, 1)); + throw requestRejectedEx2; + }, + new TestServiceContext(LoggerFactory) { ServerOptions = { Limits = { MaxRequestBodySize = 0 } } })) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "1\r\n"); + await connection.ReceiveForcedEnd( + "HTTP/1.1 413 Payload Too Large", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.NotNull(requestRejectedEx1); + Assert.NotNull(requestRejectedEx2); + Assert.Equal(CoreStrings.BadRequest_RequestBodyTooLarge, requestRejectedEx1.Message); + Assert.Equal(CoreStrings.BadRequest_RequestBodyTooLarge, requestRejectedEx2.Message); + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/MaxRequestBufferSizeTests.cs b/src/Servers/Kestrel/test/FunctionalTests/MaxRequestBufferSizeTests.cs new file mode 100644 index 0000000000..64522f2202 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/MaxRequestBufferSizeTests.cs @@ -0,0 +1,347 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class MaxRequestBufferSizeTests : LoggedTest + { + private const int _dataLength = 20 * 1024 * 1024; + + private static readonly string[] _requestLines = new[] + { + "POST / HTTP/1.0\r\n", + $"Content-Length: {_dataLength}\r\n", + "\r\n" + }; + + public static IEnumerable LargeUploadData + { + get + { + var maxRequestBufferSizeValues = new Tuple[] { + // Smallest buffer that can hold a test request line without causing + // the server to hang waiting for the end of the request line or + // a header line. + Tuple.Create((long?)(_requestLines.Max(line => line.Length)), true), + + // Small buffer, but large enough to hold all request headers. + Tuple.Create((long?)16 * 1024, true), + + // Default buffer. + Tuple.Create((long?)1024 * 1024, true), + + // Larger than default, but still significantly lower than data, so client should be paused. + // On Windows, the client is usually paused around (MaxRequestBufferSize + 700,000). + // On Linux, the client is usually paused around (MaxRequestBufferSize + 10,000,000). + Tuple.Create((long?)5 * 1024 * 1024, true), + + // Even though maxRequestBufferSize < _dataLength, client should not be paused since the + // OS-level buffers in client and/or server will handle the overflow. + Tuple.Create((long?)_dataLength - 1, false), + + // Buffer is exactly the same size as data. Exposed race condition where + // the connection was resumed after socket was disconnected. + Tuple.Create((long?)_dataLength, false), + + // Largest possible buffer, should never trigger backpressure. + Tuple.Create((long?)long.MaxValue, false), + + // Disables all code related to computing and limiting the size of the input buffer. + Tuple.Create((long?)null, false) + }; + var sslValues = new[] { true, false }; + + return from maxRequestBufferSize in maxRequestBufferSizeValues + from ssl in sslValues + select new object[] { + maxRequestBufferSize.Item1, + ssl, + maxRequestBufferSize.Item2 + }; + } + } + + [Theory] + [MemberData(nameof(LargeUploadData))] + public async Task LargeUpload(long? maxRequestBufferSize, bool connectionAdapter, bool expectPause) + { + // Parameters + var data = new byte[_dataLength]; + var bytesWrittenTimeout = TimeSpan.FromMilliseconds(100); + var bytesWrittenPollingInterval = TimeSpan.FromMilliseconds(bytesWrittenTimeout.TotalMilliseconds / 10); + var maxSendSize = 4096; + + // Initialize data with random bytes + (new Random()).NextBytes(data); + + var startReadingRequestBody = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientFinishedSendingRequestBody = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var lastBytesWritten = DateTime.MaxValue; + + using (var host = StartWebHost(maxRequestBufferSize, data, connectionAdapter, startReadingRequestBody, clientFinishedSendingRequestBody)) + { + var port = host.GetPort(); + using (var socket = CreateSocket(port)) + using (var stream = new NetworkStream(socket)) + { + await WritePostRequestHeaders(stream, data.Length); + + var bytesWritten = 0; + + Func sendFunc = async () => + { + while (bytesWritten < data.Length) + { + var size = Math.Min(data.Length - bytesWritten, maxSendSize); + await stream.WriteAsync(data, bytesWritten, size).ConfigureAwait(false); + bytesWritten += size; + lastBytesWritten = DateTime.Now; + } + + Assert.Equal(data.Length, bytesWritten); + clientFinishedSendingRequestBody.TrySetResult(null); + }; + + var sendTask = sendFunc(); + + if (expectPause) + { + // The minimum is (maxRequestBufferSize - maxSendSize + 1), since if bytesWritten is + // (maxRequestBufferSize - maxSendSize) or smaller, the client should be able to + // complete another send. + var minimumExpectedBytesWritten = maxRequestBufferSize.Value - maxSendSize + 1; + + // The maximum is harder to determine, since there can be OS-level buffers in both the client + // and server, which allow the client to send more than maxRequestBufferSize before getting + // paused. We assume the combined buffers are smaller than the difference between + // data.Length and maxRequestBufferSize. + var maximumExpectedBytesWritten = data.Length - 1; + + // Block until the send task has gone a while without writing bytes AND + // the bytes written exceeds the minimum expected. This indicates the server buffer + // is full. + // + // If the send task is paused before the expected number of bytes have been + // written, keep waiting since the pause may have been caused by something else + // like a slow machine. + while ((DateTime.Now - lastBytesWritten) < bytesWrittenTimeout || + bytesWritten < minimumExpectedBytesWritten) + { + await Task.Delay(bytesWrittenPollingInterval); + } + + // Verify the number of bytes written before the client was paused. + Assert.InRange(bytesWritten, minimumExpectedBytesWritten, maximumExpectedBytesWritten); + + // Tell server to start reading request body + startReadingRequestBody.TrySetResult(null); + + // Wait for sendTask to finish sending the remaining bytes + await sendTask; + } + else + { + // Ensure all bytes can be sent before the server starts reading + await sendTask; + + // Tell server to start reading request body + startReadingRequestBody.TrySetResult(null); + } + + using (var reader = new StreamReader(stream, Encoding.ASCII)) + { + var response = reader.ReadToEnd(); + Assert.Contains($"bytesRead: {data.Length}", response); + } + } + } + } + + [Fact] + public async Task ServerShutsDownGracefullyWhenMaxRequestBufferSizeExceeded() + { + // Parameters + var data = new byte[_dataLength]; + var bytesWrittenTimeout = TimeSpan.FromMilliseconds(100); + var bytesWrittenPollingInterval = TimeSpan.FromMilliseconds(bytesWrittenTimeout.TotalMilliseconds / 10); + var maxSendSize = 4096; + + var startReadingRequestBody = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientFinishedSendingRequestBody = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var lastBytesWritten = DateTime.MaxValue; + + using (var host = StartWebHost(16 * 1024, data, false, startReadingRequestBody, clientFinishedSendingRequestBody)) + { + var port = host.GetPort(); + using (var socket = CreateSocket(port)) + using (var stream = new NetworkStream(socket)) + { + await WritePostRequestHeaders(stream, data.Length); + + var bytesWritten = 0; + + Func sendFunc = async () => + { + while (bytesWritten < data.Length) + { + var size = Math.Min(data.Length - bytesWritten, maxSendSize); + await stream.WriteAsync(data, bytesWritten, size).ConfigureAwait(false); + bytesWritten += size; + lastBytesWritten = DateTime.Now; + } + + clientFinishedSendingRequestBody.TrySetResult(null); + }; + + var ignore = sendFunc(); + + // The minimum is (maxRequestBufferSize - maxSendSize + 1), since if bytesWritten is + // (maxRequestBufferSize - maxSendSize) or smaller, the client should be able to + // complete another send. + var minimumExpectedBytesWritten = (16 * 1024) - maxSendSize + 1; + + // The maximum is harder to determine, since there can be OS-level buffers in both the client + // and server, which allow the client to send more than maxRequestBufferSize before getting + // paused. We assume the combined buffers are smaller than the difference between + // data.Length and maxRequestBufferSize. + var maximumExpectedBytesWritten = data.Length - 1; + + // Block until the send task has gone a while without writing bytes AND + // the bytes written exceeds the minimum expected. This indicates the server buffer + // is full. + // + // If the send task is paused before the expected number of bytes have been + // written, keep waiting since the pause may have been caused by something else + // like a slow machine. + while ((DateTime.Now - lastBytesWritten) < bytesWrittenTimeout || + bytesWritten < minimumExpectedBytesWritten) + { + await Task.Delay(bytesWrittenPollingInterval); + } + + // Verify the number of bytes written before the client was paused. + Assert.InRange(bytesWritten, minimumExpectedBytesWritten, maximumExpectedBytesWritten); + + // Dispose host prior to closing connection to verify the server doesn't throw during shutdown + // if a connection no longer has alloc and read callbacks configured. + host.Dispose(); + } + } + } + + private IWebHost StartWebHost(long? maxRequestBufferSize, + byte[] expectedBody, + bool useConnectionAdapter, + TaskCompletionSource startReadingRequestBody, + TaskCompletionSource clientFinishedSendingRequestBody) + { + var host = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0), listenOptions => + { + if (useConnectionAdapter) + { + listenOptions.ConnectionAdapters.Add(new PassThroughConnectionAdapter()); + } + }); + + options.Limits.MaxRequestBufferSize = maxRequestBufferSize; + + if (maxRequestBufferSize.HasValue && + maxRequestBufferSize.Value < options.Limits.MaxRequestLineSize) + { + options.Limits.MaxRequestLineSize = (int)maxRequestBufferSize; + } + + if (maxRequestBufferSize.HasValue && + maxRequestBufferSize.Value < options.Limits.MaxRequestHeadersTotalSize) + { + options.Limits.MaxRequestHeadersTotalSize = (int)maxRequestBufferSize; + } + + options.Limits.MinRequestBodyDataRate = null; + }) + .UseContentRoot(Directory.GetCurrentDirectory()) + .Configure(app => app.Run(async context => + { + await startReadingRequestBody.Task.TimeoutAfter(TimeSpan.FromSeconds(120)); + + var buffer = new byte[expectedBody.Length]; + var bytesRead = 0; + while (bytesRead < buffer.Length) + { + bytesRead += await context.Request.Body.ReadAsync(buffer, bytesRead, buffer.Length - bytesRead); + } + + await clientFinishedSendingRequestBody.Task.TimeoutAfter(TimeSpan.FromSeconds(120)); + + // Verify client didn't send extra bytes + if (await context.Request.Body.ReadAsync(new byte[1], 0, 1) != 0) + { + context.Response.StatusCode = StatusCodes.Status500InternalServerError; + await context.Response.WriteAsync("Client sent more bytes than expectedBody.Length"); + return; + } + + // Verify bytes received match expectedBody + for (int i = 0; i < expectedBody.Length; i++) + { + if (buffer[i] != expectedBody[i]) + { + context.Response.StatusCode = StatusCodes.Status500InternalServerError; + await context.Response.WriteAsync($"Bytes received do not match expectedBody at position {i}"); + return; + } + } + + await context.Response.WriteAsync($"bytesRead: {bytesRead.ToString()}"); + })) + .Build(); + + host.Start(); + + return host; + } + + private static Socket CreateSocket(int port) + { + var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + + // Timeouts large enough to prevent false positives, but small enough to fail quickly. + socket.SendTimeout = 10 * 1000; + socket.ReceiveTimeout = 120 * 1000; + + socket.Connect(IPAddress.Loopback, port); + + return socket; + } + + private static async Task WritePostRequestHeaders(Stream stream, int contentLength) + { + using (var writer = new StreamWriter(stream, Encoding.ASCII, bufferSize: 1024, leaveOpen: true)) + { + foreach (var line in _requestLines) + { + await writer.WriteAsync(line).ConfigureAwait(false); + } + } + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/MaxRequestLineSizeTests.cs b/src/Servers/Kestrel/test/FunctionalTests/MaxRequestLineSizeTests.cs new file mode 100644 index 0000000000..6007877262 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/MaxRequestLineSizeTests.cs @@ -0,0 +1,87 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class MaxRequestLineSizeTests : LoggedTest + { + [Theory] + [InlineData("GET / HTTP/1.1\r\nHost:\r\n\r\n", 16)] + [InlineData("GET / HTTP/1.1\r\nHost:\r\n\r\n", 17)] + [InlineData("GET / HTTP/1.1\r\nHost:\r\n\r\n", 137)] + [InlineData("POST /abc/de HTTP/1.1\r\nHost:\r\nContent-Length: 0\r\n\r\n", 23)] + [InlineData("POST /abc/de HTTP/1.1\r\nHost:\r\nContent-Length: 0\r\n\r\n", 24)] + [InlineData("POST /abc/de HTTP/1.1\r\nHost:\r\nContent-Length: 0\r\n\r\n", 287)] + [InlineData("PUT /abc/de?f=ghi HTTP/1.1\r\nHost:\r\nContent-Length: 0\r\n\r\n", 28)] + [InlineData("PUT /abc/de?f=ghi HTTP/1.1\r\nHost:\r\nContent-Length: 0\r\n\r\n", 29)] + [InlineData("PUT /abc/de?f=ghi HTTP/1.1\r\nHost:\r\nContent-Length: 0\r\n\r\n", 589)] + [InlineData("DELETE /a%20b%20c/d%20e?f=ghi HTTP/1.1\r\nHost:\r\n\r\n", 40)] + [InlineData("DELETE /a%20b%20c/d%20e?f=ghi HTTP/1.1\r\nHost:\r\n\r\n", 41)] + [InlineData("DELETE /a%20b%20c/d%20e?f=ghi HTTP/1.1\r\nHost:\r\n\r\n", 1027)] + public async Task ServerAcceptsRequestLineWithinLimit(string request, int limit) + { + using (var server = CreateServer(limit)) + { + using (var connection = new TestConnection(server.Port)) + { + await connection.Send(request); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + } + } + + [Theory] + [InlineData("GET / HTTP/1.1\r\n")] + [InlineData("POST /abc/de HTTP/1.1\r\n")] + [InlineData("PUT /abc/de?f=ghi HTTP/1.1\r\n")] + [InlineData("DELETE /a%20b%20c/d%20e?f=ghi HTTP/1.1\r\n")] + public async Task ServerRejectsRequestLineExceedingLimit(string requestLine) + { + using (var server = CreateServer(requestLine.Length - 1)) + { + using (var connection = new TestConnection(server.Port)) + { + await connection.SendAll(requestLine); + await connection.ReceiveForcedEnd( + "HTTP/1.1 414 URI Too Long", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + private TestServer CreateServer(int maxRequestLineSize) + { + return new TestServer(async httpContext => await httpContext.Response.WriteAsync("hello, world"), new TestServiceContext(LoggerFactory) + { + ServerOptions = new KestrelServerOptions + { + AddServerHeader = false, + Limits = + { + MaxRequestLineSize = maxRequestLineSize + } + } + }); + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/Properties/AssemblyInfo.cs b/src/Servers/Kestrel/test/FunctionalTests/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..5e05b4461e --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/Properties/AssemblyInfo.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.Extensions.Logging.Testing; +#if MACOS +using Xunit; +#endif + +[assembly: ShortClassName] +#if MACOS +[assembly: CollectionBehavior(DisableTestParallelization = true)] +#endif diff --git a/src/Servers/Kestrel/test/FunctionalTests/RequestBodyTimeoutTests.cs b/src/Servers/Kestrel/test/FunctionalTests/RequestBodyTimeoutTests.cs new file mode 100644 index 0000000000..9d292e53fc --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/RequestBodyTimeoutTests.cs @@ -0,0 +1,209 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class RequestBodyTimeoutTests : LoggedTest + { + [Fact] + public async Task RequestTimesOutWhenRequestBodyNotReceivedAtSpecifiedMinimumRate() + { + var gracePeriod = TimeSpan.FromSeconds(5); + var systemClock = new MockSystemClock(); + var serviceContext = new TestServiceContext(LoggerFactory) + { + SystemClock = systemClock, + DateHeaderValueManager = new DateHeaderValueManager(systemClock) + }; + + var appRunningEvent = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(context => + { + context.Features.Get().MinDataRate = + new MinDataRate(bytesPerSecond: 1, gracePeriod: gracePeriod); + + // The server must call Request.Body.ReadAsync() *before* the test sets systemClock.UtcNow (which is triggered by the + // server calling appRunningEvent.SetResult(null)). If systemClock.UtcNow is set first, it's possible for the test to fail + // due to the following race condition: + // + // 1. [test] systemClock.UtcNow += gracePeriod + TimeSpan.FromSeconds(1); + // 2. [server] Heartbeat._timer is triggered, which calls HttpConnection.Tick() + // 3. [server] HttpConnection.Tick() calls HttpConnection.CheckForReadDataRateTimeout() + // 4. [server] HttpConnection.CheckForReadDataRateTimeout() is a no-op, since _readTimingEnabled is false, + // since Request.Body.ReadAsync() has not been called yet + // 5. [server] HttpConnection.Tick() sets _lastTimestamp = timestamp + // 6. [server] Request.Body.ReadAsync() is called + // 6. [test] systemClock.UtcNow is never updated again, so server timestamp is never updated, + // so HttpConnection.CheckForReadDataRateTimeout() is always a no-op until test fails + // + // This is a pretty tight race, since the once-per-second Heartbeat._timer needs to fire between the test updating + // systemClock.UtcNow and the server calling Request.Body.ReadAsync(). But it happened often enough to cause + // test flakiness in our CI (https://github.com/aspnet/KestrelHttpServer/issues/2539). + // + // For verification, I was able to induce the race by adding a sleep in the RequestDelegate: + // appRunningEvent.SetResult(null); + // Thread.Sleep(5000); + // return context.Request.Body.ReadAsync(new byte[1], 0, 1); + + var readTask = context.Request.Body.ReadAsync(new byte[1], 0, 1); + appRunningEvent.SetResult(null); + return readTask; + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 1", + "", + ""); + + await appRunningEvent.Task.DefaultTimeout(); + systemClock.UtcNow += gracePeriod + TimeSpan.FromSeconds(1); + + await connection.Receive( + "HTTP/1.1 408 Request Timeout", + ""); + await connection.ReceiveForcedEnd( + "Connection: close", + $"Date: {serviceContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Fact] + public async Task RequestTimesOutWhenNotDrainedWithinDrainTimeoutPeriod() + { + // This test requires a real clock since we can't control when the drain timeout is set + var systemClock = new SystemClock(); + var serviceContext = new TestServiceContext(LoggerFactory) + { + SystemClock = systemClock, + DateHeaderValueManager = new DateHeaderValueManager(systemClock), + }; + + var appRunningEvent = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(context => + { + context.Features.Get().MinDataRate = null; + + appRunningEvent.SetResult(null); + return Task.CompletedTask; + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 1", + "", + ""); + + await appRunningEvent.Task.DefaultTimeout(); + + await connection.Receive( + "HTTP/1.1 200 OK", + ""); + await connection.ReceiveStartsWith( + "Date: "); + // Disconnected due to the timeout + await connection.ReceiveForcedEnd( + "Content-Length: 0", + "", + ""); + } + } + + Assert.Contains(TestSink.Writes, w => w.EventId.Id == 32 && w.LogLevel == LogLevel.Information); + Assert.Contains(TestSink.Writes, w => w.EventId.Id == 33 && w.LogLevel == LogLevel.Information); + } + + [Fact] + public async Task ConnectionClosedEvenIfAppSwallowsException() + { + var gracePeriod = TimeSpan.FromSeconds(5); + var systemClock = new MockSystemClock(); + var serviceContext = new TestServiceContext(LoggerFactory) + { + SystemClock = systemClock, + DateHeaderValueManager = new DateHeaderValueManager(systemClock) + }; + + var appRunningTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var exceptionSwallowedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(async context => + { + context.Features.Get().MinDataRate = + new MinDataRate(bytesPerSecond: 1, gracePeriod: gracePeriod); + + // See comment in RequestTimesOutWhenRequestBodyNotReceivedAtSpecifiedMinimumRate for + // why we call ReadAsync before setting the appRunningEvent. + var readTask = context.Request.Body.ReadAsync(new byte[1], 0, 1); + appRunningTcs.SetResult(null); + + try + { + await readTask; + } + catch (BadHttpRequestException ex) when (ex.StatusCode == 408) + { + exceptionSwallowedTcs.SetResult(null); + } + catch (Exception ex) + { + exceptionSwallowedTcs.SetException(ex); + } + + var response = "hello, world"; + context.Response.ContentLength = response.Length; + await context.Response.WriteAsync("hello, world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 1", + "", + ""); + + await appRunningTcs.Task.DefaultTimeout(); + systemClock.UtcNow += gracePeriod + TimeSpan.FromSeconds(1); + await exceptionSwallowedTcs.Task.DefaultTimeout(); + + await connection.Receive( + "HTTP/1.1 200 OK", + ""); + await connection.ReceiveForcedEnd( + $"Date: {serviceContext.DateHeaderValue}", + "Content-Length: 12", + "", + "hello, world"); + } + } + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/RequestHeaderLimitsTests.cs b/src/Servers/Kestrel/test/FunctionalTests/RequestHeaderLimitsTests.cs new file mode 100644 index 0000000000..6de4c37a99 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/RequestHeaderLimitsTests.cs @@ -0,0 +1,158 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class RequestHeaderLimitsTests : LoggedTest + { + [Theory] + [InlineData(0, 1)] + [InlineData(0, 1337)] + [InlineData(1, 0)] + [InlineData(1, 1)] + [InlineData(1, 1337)] + [InlineData(5, 0)] + [InlineData(5, 1)] + [InlineData(5, 1337)] + public async Task ServerAcceptsRequestWithHeaderTotalSizeWithinLimit(int headerCount, int extraLimit) + { + var headers = MakeHeaders(headerCount); + + using (var server = CreateServer(maxRequestHeadersTotalSize: headers.Length + extraLimit)) + { + using (var connection = new TestConnection(server.Port)) + { + await connection.Send($"GET / HTTP/1.1\r\n{headers}\r\n"); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + } + } + + [Theory] + [InlineData(0, 1)] + [InlineData(0, 1337)] + [InlineData(1, 1)] + [InlineData(1, 2)] + [InlineData(1, 1337)] + [InlineData(5, 5)] + [InlineData(5, 6)] + [InlineData(5, 1337)] + public async Task ServerAcceptsRequestWithHeaderCountWithinLimit(int headerCount, int maxHeaderCount) + { + var headers = MakeHeaders(headerCount); + + using (var server = CreateServer(maxRequestHeaderCount: maxHeaderCount)) + { + using (var connection = new TestConnection(server.Port)) + { + await connection.Send($"GET / HTTP/1.1\r\n{headers}\r\n"); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + } + } + + [Theory] + [InlineData(1)] + [InlineData(5)] + public async Task ServerRejectsRequestWithHeaderTotalSizeOverLimit(int headerCount) + { + var headers = MakeHeaders(headerCount); + + using (var server = CreateServer(maxRequestHeadersTotalSize: headers.Length - 1)) + { + using (var connection = new TestConnection(server.Port)) + { + await connection.SendAll($"GET / HTTP/1.1\r\n{headers}\r\n"); + await connection.ReceiveForcedEnd( + "HTTP/1.1 431 Request Header Fields Too Large", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [InlineData(2, 1)] + [InlineData(5, 1)] + [InlineData(5, 4)] + public async Task ServerRejectsRequestWithHeaderCountOverLimit(int headerCount, int maxHeaderCount) + { + var headers = MakeHeaders(headerCount); + + using (var server = CreateServer(maxRequestHeaderCount: maxHeaderCount)) + { + using (var connection = new TestConnection(server.Port)) + { + await connection.SendAll($"GET / HTTP/1.1\r\n{headers}\r\n"); + await connection.ReceiveForcedEnd( + "HTTP/1.1 431 Request Header Fields Too Large", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + private static string MakeHeaders(int count) + { + const string host = "Host:\r\n"; + if (count <= 1) return host; + + return string.Join("", new[] { host } + .Concat(Enumerable + .Range(0, count -1) + .Select(i => $"Header-{i}: value{i}\r\n"))); + } + + private TestServer CreateServer(int? maxRequestHeaderCount = null, int? maxRequestHeadersTotalSize = null) + { + var options = new KestrelServerOptions { AddServerHeader = false }; + + if (maxRequestHeaderCount.HasValue) + { + options.Limits.MaxRequestHeaderCount = maxRequestHeaderCount.Value; + } + + if (maxRequestHeadersTotalSize.HasValue) + { + options.Limits.MaxRequestHeadersTotalSize = maxRequestHeadersTotalSize.Value; + } + + return new TestServer(async httpContext => await httpContext.Response.WriteAsync("hello, world"), new TestServiceContext(LoggerFactory) + { + ServerOptions = options + }); + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/test/FunctionalTests/RequestHeadersTimeoutTests.cs b/src/Servers/Kestrel/test/FunctionalTests/RequestHeadersTimeoutTests.cs new file mode 100644 index 0000000000..57fe8fc274 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/RequestHeadersTimeoutTests.cs @@ -0,0 +1,145 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class RequestHeadersTimeoutTests : LoggedTest + { + private static readonly TimeSpan RequestHeadersTimeout = TimeSpan.FromSeconds(10); + private static readonly TimeSpan LongDelay = TimeSpan.FromSeconds(30); + private static readonly TimeSpan ShortDelay = TimeSpan.FromSeconds(LongDelay.TotalSeconds / 10); + + [Fact] + public async Task TestRequestHeadersTimeout() + { + using (var server = CreateServer()) + { + var tasks = new[] + { + ConnectionAbortedWhenRequestHeadersNotReceivedInTime(server, "Host:\r\n"), + ConnectionAbortedWhenRequestHeadersNotReceivedInTime(server, "Host:\r\nContent-Length: 1\r\n"), + ConnectionAbortedWhenRequestHeadersNotReceivedInTime(server, "Host:\r\nContent-Length: 1\r\n\r"), + RequestHeadersTimeoutCanceledAfterHeadersReceived(server), + ConnectionAbortedWhenRequestLineNotReceivedInTime(server, "P"), + ConnectionAbortedWhenRequestLineNotReceivedInTime(server, "POST / HTTP/1.1\r"), + TimeoutNotResetOnEachRequestLineCharacterReceived(server) + }; + + await Task.WhenAll(tasks); + } + } + + private async Task ConnectionAbortedWhenRequestHeadersNotReceivedInTime(TestServer server, string headers) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + headers); + await ReceiveTimeoutResponse(connection); + } + } + + private async Task RequestHeadersTimeoutCanceledAfterHeadersReceived(TestServer server) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 1", + "", + ""); + await Task.Delay(RequestHeadersTimeout); + await connection.Send( + "a"); + await ReceiveResponse(connection); + } + } + + private async Task ConnectionAbortedWhenRequestLineNotReceivedInTime(TestServer server, string requestLine) + { + using (var connection = server.CreateConnection()) + { + await connection.Send(requestLine); + await ReceiveTimeoutResponse(connection); + } + } + + private async Task TimeoutNotResetOnEachRequestLineCharacterReceived(TestServer server) + { + using (var connection = server.CreateConnection()) + { + await Assert.ThrowsAsync(async () => + { + foreach (var ch in "POST / HTTP/1.1\r\nHost:\r\n\r\n") + { + await connection.Send(ch.ToString()); + await Task.Delay(ShortDelay); + } + }); + } + } + + private TestServer CreateServer() + { + return new TestServer(async httpContext => + { + await httpContext.Request.Body.ReadAsync(new byte[1], 0, 1); + await httpContext.Response.WriteAsync("hello, world"); + }, + new TestServiceContext(LoggerFactory) + { + // Use real SystemClock so timeouts trigger. + SystemClock = new SystemClock(), + ServerOptions = + { + AddServerHeader = false, + Limits = + { + RequestHeadersTimeout = RequestHeadersTimeout, + MinRequestBodyDataRate = null + } + } + }); + } + + private async Task ReceiveResponse(TestConnection connection) + { + await connection.Receive( + "HTTP/1.1 200 OK", + ""); + await connection.ReceiveStartsWith("Date: "); + await connection.Receive( + "Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + + private async Task ReceiveTimeoutResponse(TestConnection connection) + { + await connection.Receive( + "HTTP/1.1 408 Request Timeout", + "Connection: close", + ""); + await connection.ReceiveStartsWith("Date: "); + await connection.ReceiveForcedEnd( + "Content-Length: 0", + "", + ""); + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/test/FunctionalTests/RequestTargetProcessingTests.cs b/src/Servers/Kestrel/test/FunctionalTests/RequestTargetProcessingTests.cs new file mode 100644 index 0000000000..6bf18030a3 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/RequestTargetProcessingTests.cs @@ -0,0 +1,134 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class RequestTargetProcessingTests : LoggedTest + { + [Fact] + public async Task RequestPathIsNotNormalized() + { + var testContext = new TestServiceContext(LoggerFactory); + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + + using (var server = new TestServer(async context => + { + Assert.Equal("/\u0041\u030A/B/\u0041\u030A", context.Request.Path.Value); + + context.Response.Headers.ContentLength = 11; + await context.Response.WriteAsync("Hello World"); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET /%41%CC%8A/A/../B/%41%CC%8A HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + } + } + } + + [Theory] + [InlineData("/")] + [InlineData("/.")] + [InlineData("/..")] + [InlineData("/./.")] + [InlineData("/./..")] + [InlineData("/../.")] + [InlineData("/../..")] + [InlineData("/path")] + [InlineData("/path?foo=1&bar=2")] + [InlineData("/hello%20world")] + [InlineData("/hello%20world?foo=1&bar=2")] + [InlineData("/base/path")] + [InlineData("/base/path?foo=1&bar=2")] + [InlineData("/base/hello%20world")] + [InlineData("/base/hello%20world?foo=1&bar=2")] + public async Task RequestFeatureContainsRawTarget(string requestTarget) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async context => + { + Assert.Equal(requestTarget, context.Features.Get().RawTarget); + + context.Response.Headers["Content-Length"] = new[] { "11" }; + await context.Response.WriteAsync("Hello World"); + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + $"GET {requestTarget} HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + } + } + } + + [Theory] + [InlineData(HttpMethod.Options, "*")] + [InlineData(HttpMethod.Connect, "host")] + public async Task NonPathRequestTargetSetInRawTarget(HttpMethod method, string requestTarget) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async context => + { + Assert.Equal(requestTarget, context.Features.Get().RawTarget); + Assert.Empty(context.Request.Path.Value); + Assert.Empty(context.Request.PathBase.Value); + Assert.Empty(context.Request.QueryString.Value); + + context.Response.Headers["Content-Length"] = new[] { "11" }; + await context.Response.WriteAsync("Hello World"); + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + var host = method == HttpMethod.Connect + ? requestTarget + : string.Empty; + + await connection.Send( + $"{HttpUtilities.MethodToString(method)} {requestTarget} HTTP/1.1", + $"Host: {host}", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + } + } + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/RequestTests.cs b/src/Servers/Kestrel/test/FunctionalTests/RequestTests.cs new file mode 100644 index 0000000000..b528e80b3a --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/RequestTests.cs @@ -0,0 +1,1928 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.IO.Pipelines; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Testing; +using Microsoft.AspNetCore.Testing.xunit; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; +using Moq; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class RequestTests : LoggedTest + { + private const int _connectionStartedEventId = 1; + private const int _connectionResetEventId = 19; + private static readonly int _semaphoreWaitTimeout = Debugger.IsAttached ? 10000 : 2500; + + public static TheoryData ConnectionAdapterData => new TheoryData + { + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)), + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = { new PassThroughConnectionAdapter() } + } + }; + + [Theory] + [InlineData(10 * 1024 * 1024, true)] + // In the following dataset, send at least 2GB. + // Never change to a lower value, otherwise regression testing for + // https://github.com/aspnet/KestrelHttpServer/issues/520#issuecomment-188591242 + // will be lost. + [InlineData((long)int.MaxValue + 1, false)] + public void LargeUpload(long contentLength, bool checkBytes) + { + const int bufferLength = 1024 * 1024; + Assert.True(contentLength % bufferLength == 0, $"{nameof(contentLength)} sent must be evenly divisible by {bufferLength}."); + Assert.True(bufferLength % 256 == 0, $"{nameof(bufferLength)} must be evenly divisible by 256"); + + var builder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel(options => + { + options.Limits.MaxRequestBodySize = contentLength; + options.Limits.MinRequestBodyDataRate = null; + }) + .UseUrls("http://127.0.0.1:0/") + .Configure(app => + { + app.Run(async context => + { + // Read the full request body + long total = 0; + var receivedBytes = new byte[bufferLength]; + var received = 0; + while ((received = await context.Request.Body.ReadAsync(receivedBytes, 0, receivedBytes.Length)) > 0) + { + if (checkBytes) + { + for (var i = 0; i < received; i++) + { + // Do not use Assert.Equal here, it is to slow for this hot path + Assert.True((byte)((total + i) % 256) == receivedBytes[i], "Data received is incorrect"); + } + } + + total += received; + } + + await context.Response.WriteAsync(total.ToString(CultureInfo.InvariantCulture)); + }); + }); + + using (var host = builder.Build()) + { + host.Start(); + + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort())); + socket.Send(Encoding.ASCII.GetBytes($"POST / HTTP/1.0\r\nContent-Length: {contentLength}\r\n\r\n")); + + var contentBytes = new byte[bufferLength]; + + if (checkBytes) + { + for (var i = 0; i < contentBytes.Length; i++) + { + contentBytes[i] = (byte)i; + } + } + + for (var i = 0; i < contentLength / contentBytes.Length; i++) + { + socket.Send(contentBytes); + } + + var response = new StringBuilder(); + var responseBytes = new byte[4096]; + var received = 0; + while ((received = socket.Receive(responseBytes)) > 0) + { + response.Append(Encoding.ASCII.GetString(responseBytes, 0, received)); + } + + Assert.Contains(contentLength.ToString(CultureInfo.InvariantCulture), response.ToString()); + } + } + } + + [Fact] + public Task RemoteIPv4Address() + { + return TestRemoteIPAddress("127.0.0.1", "127.0.0.1", "127.0.0.1"); + } + + [ConditionalFact(Skip="https://github.com/aspnet/KestrelHttpServer/issues/2406")] + [IPv6SupportedCondition] + public Task RemoteIPv6Address() + { + return TestRemoteIPAddress("[::1]", "[::1]", "::1"); + } + + [Fact] + public async Task DoesNotHangOnConnectionCloseRequest() + { + var builder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls("http://127.0.0.1:0") + .ConfigureServices(AddTestLogging) + .Configure(app => + { + app.Run(async context => + { + await context.Response.WriteAsync("hello, world"); + }); + }); + + using (var host = builder.Build()) + using (var client = new HttpClient()) + { + host.Start(); + + client.DefaultRequestHeaders.Connection.Clear(); + client.DefaultRequestHeaders.Connection.Add("close"); + + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); + response.EnsureSuccessStatusCode(); + } + } + + [Fact] + public async Task StreamsAreNotPersistedAcrossRequests() + { + var requestBodyPersisted = false; + var responseBodyPersisted = false; + + var builder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls("http://127.0.0.1:0") + .ConfigureServices(AddTestLogging) + .Configure(app => + { + app.Run(async context => + { + if (context.Request.Body is MemoryStream) + { + requestBodyPersisted = true; + } + + if (context.Response.Body is MemoryStream) + { + responseBodyPersisted = true; + } + + context.Request.Body = new MemoryStream(); + context.Response.Body = new MemoryStream(); + + await context.Response.WriteAsync("hello, world"); + }); + }); + + using (var host = builder.Build()) + { + host.Start(); + + using (var client = new HttpClient { BaseAddress = new Uri($"http://127.0.0.1:{host.GetPort()}") }) + { + await client.GetAsync("/"); + await client.GetAsync("/"); + + Assert.False(requestBodyPersisted); + Assert.False(responseBodyPersisted); + } + } + } + + [Fact] + public void CanUpgradeRequestWithConnectionKeepAliveUpgradeHeader() + { + var dataRead = false; + var builder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls("http://127.0.0.1:0") + .ConfigureServices(AddTestLogging) + .Configure(app => + { + app.Run(async context => + { + var stream = await context.Features.Get().UpgradeAsync(); + var data = new byte[3]; + var bytesRead = 0; + + while (bytesRead < 3) + { + bytesRead += await stream.ReadAsync(data, bytesRead, data.Length - bytesRead); + } + + dataRead = Encoding.ASCII.GetString(data, 0, 3) == "abc"; + }); + }); + + using (var host = builder.Build()) + { + host.Start(); + + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort())); + socket.Send(Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\nConnection: keep-alive, upgrade\r\n\r\n")); + socket.Send(Encoding.ASCII.GetBytes("abc")); + + while (socket.Receive(new byte[1024]) > 0) ; + } + } + + Assert.True(dataRead); + } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.MacOSX, SkipReason = "macOS EPIPE vs. EPROTOTYPE bug https://github.com/aspnet/KestrelHttpServer/issues/2885")] + public async Task ConnectionResetPriorToRequestIsLoggedAsDebug() + { + var connectionStarted = new SemaphoreSlim(0); + var connectionReset = new SemaphoreSlim(0); + var loggedHigherThanDebug = false; + + var mockLogger = new Mock(); + mockLogger + .Setup(logger => logger.IsEnabled(It.IsAny())) + .Returns(true); + mockLogger + .Setup(logger => logger.Log(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny>())) + .Callback>((logLevel, eventId, state, exception, formatter) => + { + Logger.Log(logLevel, eventId, state, exception, formatter); + if (eventId.Id == _connectionStartedEventId) + { + connectionStarted.Release(); + } + else if (eventId.Id == _connectionResetEventId) + { + connectionReset.Release(); + } + + if (logLevel > LogLevel.Debug) + { + loggedHigherThanDebug = true; + } + }); + + var mockLoggerFactory = new Mock(); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsAny())) + .Returns(Logger); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel", + "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv", + "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"))) + .Returns(mockLogger.Object); + + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(mockLoggerFactory.Object))) + { + using (var connection = server.CreateConnection()) + { + // Wait until connection is established + Assert.True(await connectionStarted.WaitAsync(TestConstants.DefaultTimeout)); + + connection.Reset(); + } + + // If the reset is correctly logged as Debug, the wait below should complete shortly. + // This check MUST come before disposing the server, otherwise there's a race where the RST + // is still in flight when the connection is aborted, leading to the reset never being received + // and therefore not logged. + Assert.True(await connectionReset.WaitAsync(TestConstants.DefaultTimeout)); + } + + Assert.False(loggedHigherThanDebug); + } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.MacOSX, SkipReason = "macOS EPIPE vs. EPROTOTYPE bug https://github.com/aspnet/KestrelHttpServer/issues/2885")] + public async Task ConnectionResetBetweenRequestsIsLoggedAsDebug() + { + var connectionReset = new SemaphoreSlim(0); + var loggedHigherThanDebug = false; + + var mockLogger = new Mock(); + mockLogger + .Setup(logger => logger.IsEnabled(It.IsAny())) + .Returns(true); + mockLogger + .Setup(logger => logger.Log(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny>())) + .Callback>((logLevel, eventId, state, exception, formatter) => + { + Logger.Log(logLevel, eventId, state, exception, formatter); + if (eventId.Id == _connectionResetEventId) + { + connectionReset.Release(); + } + + if (logLevel > LogLevel.Debug) + { + loggedHigherThanDebug = true; + } + }); + + var mockLoggerFactory = new Mock(); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsAny())) + .Returns(Logger); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel", + "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv", + "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"))) + .Returns(mockLogger.Object); + + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(mockLoggerFactory.Object))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + // Make sure the response is fully received, so a write failure (e.g. EPIPE) doesn't cause + // a more critical log message. + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + + connection.Reset(); + // Force a reset + } + + // If the reset is correctly logged as Debug, the wait below should complete shortly. + // This check MUST come before disposing the server, otherwise there's a race where the RST + // is still in flight when the connection is aborted, leading to the reset never being received + // and therefore not logged. + Assert.True(await connectionReset.WaitAsync(TestConstants.DefaultTimeout)); + } + + Assert.False(loggedHigherThanDebug); + } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.MacOSX, SkipReason = "macOS EPIPE vs. EPROTOTYPE bug https://github.com/aspnet/KestrelHttpServer/issues/2885")] + public async Task ConnectionResetMidRequestIsLoggedAsDebug() + { + var requestStarted = new SemaphoreSlim(0); + var connectionReset = new SemaphoreSlim(0); + var connectionClosing = new SemaphoreSlim(0); + var loggedHigherThanDebug = false; + + var mockLogger = new Mock(); + mockLogger + .Setup(logger => logger.IsEnabled(It.IsAny())) + .Returns(true); + mockLogger + .Setup(logger => logger.Log(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny>())) + .Callback>((logLevel, eventId, state, exception, formatter) => + { + Logger.Log(logLevel, eventId, state, exception, formatter); + var log = $"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception}"; + TestOutputHelper.WriteLine(log); + + if (eventId.Id == _connectionResetEventId) + { + connectionReset.Release(); + } + + if (logLevel > LogLevel.Debug) + { + loggedHigherThanDebug = true; + } + }); + + var mockLoggerFactory = new Mock(); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsAny())) + .Returns(Logger); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel", + "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv", + "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"))) + .Returns(mockLogger.Object); + + using (var server = new TestServer(async context => + { + requestStarted.Release(); + await connectionClosing.WaitAsync(); + }, + new TestServiceContext(mockLoggerFactory.Object))) + { + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGet(); + + // Wait until connection is established + Assert.True(await requestStarted.WaitAsync(TestConstants.DefaultTimeout), "request should have started"); + + connection.Reset(); + } + + // If the reset is correctly logged as Debug, the wait below should complete shortly. + // This check MUST come before disposing the server, otherwise there's a race where the RST + // is still in flight when the connection is aborted, leading to the reset never being received + // and therefore not logged. + Assert.True(await connectionReset.WaitAsync(TestConstants.DefaultTimeout), "Connection reset event should have been logged"); + connectionClosing.Release(); + } + + Assert.False(loggedHigherThanDebug, "Logged event should not have been higher than debug."); + } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.MacOSX, SkipReason = "macOS EPIPE vs. EPROTOTYPE bug https://github.com/aspnet/KestrelHttpServer/issues/2885")] + public async Task ThrowsOnReadAfterConnectionError() + { + var requestStarted = new SemaphoreSlim(0); + var connectionReset = new SemaphoreSlim(0); + var appDone = new SemaphoreSlim(0); + var expectedExceptionThrown = false; + + var builder = TransportSelector.GetWebHostBuilder() + .ConfigureServices(AddTestLogging) + .UseKestrel() + .UseUrls("http://127.0.0.1:0") + .Configure(app => app.Run(async context => + { + requestStarted.Release(); + Assert.True(await connectionReset.WaitAsync(_semaphoreWaitTimeout)); + + try + { + await context.Request.Body.ReadAsync(new byte[1], 0, 1); + } + catch (ConnectionResetException) + { + expectedExceptionThrown = true; + } + + appDone.Release(); + })); + + using (var host = builder.Build()) + { + host.Start(); + + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort())); + socket.LingerState = new LingerOption(true, 0); + socket.Send(Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\nContent-Length: 1\r\n\r\n")); + Assert.True(await requestStarted.WaitAsync(_semaphoreWaitTimeout)); + } + + connectionReset.Release(); + + Assert.True(await appDone.WaitAsync(_semaphoreWaitTimeout)); + Assert.True(expectedExceptionThrown); + } + } + + [Fact] + public async Task RequestAbortedTokenFiredOnClientFIN() + { + var appStarted = new SemaphoreSlim(0); + var requestAborted = new SemaphoreSlim(0); + var builder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls("http://127.0.0.1:0") + .ConfigureServices(AddTestLogging) + .Configure(app => app.Run(async context => + { + appStarted.Release(); + + var token = context.RequestAborted; + token.Register(() => requestAborted.Release(2)); + await requestAborted.WaitAsync().DefaultTimeout(); + })); + + using (var host = builder.Build()) + { + host.Start(); + + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort())); + socket.Send(Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n")); + await appStarted.WaitAsync(); + socket.Shutdown(SocketShutdown.Send); + await requestAborted.WaitAsync().DefaultTimeout(); + } + } + } + + [Fact] + public void AbortingTheConnectionSendsFIN() + { + var builder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls("http://127.0.0.1:0") + .ConfigureServices(AddTestLogging) + .Configure(app => app.Run(context => + { + context.Abort(); + return Task.CompletedTask; + })); + + using (var host = builder.Build()) + { + host.Start(); + + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort())); + socket.Send(Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n")); + int result = socket.Receive(new byte[32]); + Assert.Equal(0, result); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ConnectionClosedTokenFiresOnClientFIN(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + var appStartedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var connectionClosedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(context => + { + appStartedTcs.SetResult(null); + + var connectionLifetimeFeature = context.Features.Get(); + connectionLifetimeFeature.ConnectionClosed.Register(() => connectionClosedTcs.SetResult(null)); + + return Task.CompletedTask; + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + await appStartedTcs.Task.DefaultTimeout(); + + connection.Shutdown(SocketShutdown.Send); + + await connectionClosedTcs.Task.DefaultTimeout(); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ConnectionClosedTokenFiresOnServerFIN(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + var connectionClosedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(context => + { + var connectionLifetimeFeature = context.Features.Get(); + connectionLifetimeFeature.ConnectionClosed.Register(() => connectionClosedTcs.SetResult(null)); + + return Task.CompletedTask; + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Connection: close", + "", + ""); + + await connectionClosedTcs.Task.DefaultTimeout(); + + await connection.ReceiveEnd($"HTTP/1.1 200 OK", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ConnectionClosedTokenFiresOnServerAbort(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + var connectionClosedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(context => + { + var connectionLifetimeFeature = context.Features.Get(); + connectionLifetimeFeature.ConnectionClosed.Register(() => connectionClosedTcs.SetResult(null)); + + context.Abort(); + + return Task.CompletedTask; + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + await connectionClosedTcs.Task.DefaultTimeout(); + await connection.ReceiveForcedEnd(); + } + } + } + + [Theory] + [InlineData("http://localhost/abs/path", "/abs/path", null)] + [InlineData("https://localhost/abs/path", "/abs/path", null)] // handles mismatch scheme + [InlineData("https://localhost:22/abs/path", "/abs/path", null)] // handles mismatched ports + [InlineData("https://differenthost/abs/path", "/abs/path", null)] // handles mismatched hostname + [InlineData("http://localhost/", "/", null)] + [InlineData("http://root@contoso.com/path", "/path", null)] + [InlineData("http://root:password@contoso.com/path", "/path", null)] + [InlineData("https://localhost/", "/", null)] + [InlineData("http://localhost", "/", null)] + [InlineData("http://127.0.0.1/", "/", null)] + [InlineData("http://[::1]/", "/", null)] + [InlineData("http://[::1]:8080/", "/", null)] + [InlineData("http://localhost?q=123&w=xyz", "/", "123")] + [InlineData("http://localhost/?q=123&w=xyz", "/", "123")] + [InlineData("http://localhost/path?q=123&w=xyz", "/path", "123")] + [InlineData("http://localhost/path%20with%20space?q=abc%20123", "/path with space", "abc 123")] + public async Task CanHandleRequestsWithUrlInAbsoluteForm(string requestUrl, string expectedPath, string queryValue) + { + var pathTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var rawTargetTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var queryTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(async context => + { + pathTcs.TrySetResult(context.Request.Path); + queryTcs.TrySetResult(context.Request.Query); + rawTargetTcs.TrySetResult(context.Features.Get().RawTarget); + await context.Response.WriteAsync("Done"); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + var requestTarget = new Uri(requestUrl, UriKind.Absolute); + var host = requestTarget.Authority; + if (requestTarget.IsDefaultPort) + { + host += ":" + requestTarget.Port; + } + + await connection.Send( + $"GET {requestUrl} HTTP/1.1", + "Content-Length: 0", + $"Host: {host}", + "", + ""); + + await connection.Receive($"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "4", + "Done") + .DefaultTimeout(); + + await Task.WhenAll(pathTcs.Task, rawTargetTcs.Task, queryTcs.Task).DefaultTimeout(); + Assert.Equal(new PathString(expectedPath), pathTcs.Task.Result); + Assert.Equal(requestUrl, rawTargetTcs.Task.Result); + if (queryValue == null) + { + Assert.False(queryTcs.Task.Result.ContainsKey("q")); + } + else + { + Assert.Equal(queryValue, queryTcs.Task.Result["q"]); + } + } + } + } + + [Fact] + public async Task AppCanSetTraceIdentifier() + { + const string knownId = "xyz123"; + using (var server = new TestServer(async context => + { + context.TraceIdentifier = knownId; + await context.Response.WriteAsync(context.TraceIdentifier); + }, new TestServiceContext(LoggerFactory))) + { + var requestId = await HttpClientSlim.GetStringAsync($"http://{server.EndPoint}") + .DefaultTimeout(); + Assert.Equal(knownId, requestId); + } + } + + [Fact] + public async Task TraceIdentifierIsUnique() + { + const int identifierLength = 22; + const int iterations = 10; + + using (var server = new TestServer(async context => + { + Assert.Equal(identifierLength, Encoding.ASCII.GetByteCount(context.TraceIdentifier)); + context.Response.ContentLength = identifierLength; + await context.Response.WriteAsync(context.TraceIdentifier); + }, new TestServiceContext(LoggerFactory))) + { + var usedIds = new ConcurrentBag(); + var uri = $"http://{server.EndPoint}"; + + // requests on separate connections in parallel + Parallel.For(0, iterations, async i => + { + var id = await HttpClientSlim.GetStringAsync(uri); + Assert.DoesNotContain(id, usedIds.ToArray()); + usedIds.Add(id); + }); + + // requests on same connection + using (var connection = server.CreateConnection()) + { + var buffer = new char[identifierLength]; + for (var i = 0; i < iterations; i++) + { + await connection.SendEmptyGet(); + + await connection.Receive($"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Content-Length: {identifierLength}", + "", + "").DefaultTimeout(); + + var read = await connection.Reader.ReadAsync(buffer, 0, identifierLength); + Assert.Equal(identifierLength, read); + var id = new string(buffer, 0, read); + Assert.DoesNotContain(id, usedIds.ToArray()); + usedIds.Add(id); + } + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task Http11KeptAliveByDefault(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EchoAppChunked, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + "GET / HTTP/1.1", + "Host:", + "Connection: close", + "Content-Length: 7", + "", + "Goodbye"); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 7", + "", + "Goodbye"); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task Http10NotKeptAliveByDefault(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EchoApp, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.0", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.0", + "Content-Length: 11", + "", + "Hello World"); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "", + "Hello World"); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task Http10KeepAlive(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EchoAppChunked, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.0", + "Connection: keep-alive", + "", + "POST / HTTP/1.0", + "Content-Length: 7", + "", + "Goodbye"); + await connection.Receive( + "HTTP/1.1 200 OK", + "Connection: keep-alive", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "\r\n"); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 7", + "", + "Goodbye"); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task Http10KeepAliveNotHonoredIfResponseContentLengthNotSet(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EchoApp, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.0", + "Connection: keep-alive", + "", + ""); + + await connection.Receive( + "HTTP/1.1 200 OK", + "Connection: keep-alive", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "\r\n"); + + await connection.Send( + "POST / HTTP/1.0", + "Connection: keep-alive", + "Content-Length: 7", + "", + "Goodbye"); + + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "", + "Goodbye"); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task Http10KeepAliveHonoredIfResponseContentLengthSet(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EchoAppChunked, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.0", + "Content-Length: 11", + "Connection: keep-alive", + "", + "Hello World"); + + await connection.Receive( + "HTTP/1.1 200 OK", + "Connection: keep-alive", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + + await connection.Send( + "POST / HTTP/1.0", + "Connection: keep-alive", + "Content-Length: 11", + "", + "Hello Again"); + + await connection.Receive( + "HTTP/1.1 200 OK", + "Connection: keep-alive", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello Again"); + + await connection.Send( + "POST / HTTP/1.0", + "Content-Length: 7", + "", + "Goodbye"); + + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 7", + "", + "Goodbye"); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task Expect100ContinueHonored(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EchoAppChunked, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Expect: 100-continue", + "Connection: close", + "Content-Length: 11", + "\r\n"); + await connection.Receive( + "HTTP/1.1 100 Continue", + "", + ""); + await connection.Send("Hello World"); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ZeroContentLengthAssumedOnNonKeepAliveRequestsWithoutContentLengthOrTransferEncodingHeader(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + // This will hang if 0 content length is not assumed by the server + Assert.Equal(0, await httpContext.Request.Body.ReadAsync(new byte[1], 0, 1).DefaultTimeout()); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + // Use Send instead of SendEnd to ensure the connection will remain open while + // the app runs and reads 0 bytes from the body nonetheless. This checks that + // https://github.com/aspnet/KestrelHttpServer/issues/1104 is not regressing. + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Connection: close", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.0", + "Host:", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ConnectionClosesWhenFinReceivedBeforeRequestCompletes(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + // FIN callbacks are scheduled so run inline to make this test more reliable + testContext.Scheduler = PipeScheduler.Inline; + + using (var server = new TestServer(TestApp.EchoAppChunked, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1"); + connection.Shutdown(SocketShutdown.Send); + await connection.ReceiveForcedEnd(); + } + + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 7"); + connection.Shutdown(SocketShutdown.Send); + await connection.ReceiveForcedEnd(); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task RequestsCanBeAbortedMidRead(ListenOptions listenOptions) + { + const int applicationAbortedConnectionId = 34; + + var testContext = new TestServiceContext(LoggerFactory); + + var readTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var registrationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var requestId = 0; + + using (var server = new TestServer(async httpContext => + { + requestId++; + + var response = httpContext.Response; + var request = httpContext.Request; + var lifetime = httpContext.Features.Get(); + + lifetime.RequestAborted.Register(() => registrationTcs.TrySetResult(requestId)); + + if (requestId == 1) + { + response.Headers["Content-Length"] = new[] { "5" }; + + await response.WriteAsync("World"); + } + else + { + var readTask = request.Body.CopyToAsync(Stream.Null); + + lifetime.Abort(); + + try + { + await readTask; + } + catch (Exception ex) + { + readTcs.SetException(ex); + throw; + } + + readTcs.SetException(new Exception("This shouldn't be reached.")); + } + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + // Full request and response + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "Hello"); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 5", + "", + "World"); + + // Never send the body so CopyToAsync always fails. + await connection.Send("POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + ""); + await connection.WaitForConnectionClose(); + } + } + + await Assert.ThrowsAsync(async () => await readTcs.Task); + + // The cancellation token for only the last request should be triggered. + var abortedRequestId = await registrationTcs.Task; + Assert.Equal(2, abortedRequestId); + + Assert.Single(TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.Server.Kestrel" && + w.EventId == applicationAbortedConnectionId)); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ServerCanAbortConnectionAfterUnobservedClose(ListenOptions listenOptions) + { + const int connectionPausedEventId = 4; + const int connectionFinSentEventId = 7; + const int maxRequestBufferSize = 4096; + + var readCallbackUnwired = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientClosedConnection = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var serverClosedConnection = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appFuncCompleted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var mockLogger = new Mock(); + mockLogger + .Setup(logger => logger.IsEnabled(It.IsAny())) + .Returns(true); + mockLogger + .Setup(logger => logger.Log(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny>())) + .Callback>((logLevel, eventId, state, exception, formatter) => + { + if (eventId.Id == connectionPausedEventId) + { + readCallbackUnwired.TrySetResult(null); + } + else if (eventId.Id == connectionFinSentEventId) + { + serverClosedConnection.SetResult(null); + } + + Logger.Log(logLevel, eventId, state, exception, formatter); + }); + + var mockLoggerFactory = new Mock(); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsAny())) + .Returns(Logger); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv", + "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"))) + .Returns(mockLogger.Object); + + var mockKestrelTrace = new Mock(Logger) { CallBase = true }; + var testContext = new TestServiceContext(mockLoggerFactory.Object) + { + Log = mockKestrelTrace.Object, + ServerOptions = + { + Limits = + { + MaxRequestBufferSize = maxRequestBufferSize, + MaxRequestLineSize = maxRequestBufferSize, + MaxRequestHeadersTotalSize = maxRequestBufferSize, + } + } + }; + + var scratchBuffer = new byte[maxRequestBufferSize * 8]; + + using (var server = new TestServer(async context => + { + await clientClosedConnection.Task; + + context.Abort(); + + await serverClosedConnection.Task; + + appFuncCompleted.SetResult(null); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + $"Content-Length: {scratchBuffer.Length}", + "", + ""); + + var ignore = connection.Stream.WriteAsync(scratchBuffer, 0, scratchBuffer.Length); + + // Wait until the read callback is no longer hooked up so that the connection disconnect isn't observed. + await readCallbackUnwired.Task.DefaultTimeout(); + } + + clientClosedConnection.SetResult(null); + + await appFuncCompleted.Task.DefaultTimeout(); + } + + mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny()), Times.AtMostOnce()); + } + + [ConditionalTheory] + [OSSkipCondition(OperatingSystems.MacOSX, SkipReason = "macOS EPIPE vs. EPROTOTYPE bug https://github.com/aspnet/KestrelHttpServer/issues/2885")] + [MemberData(nameof(ConnectionAdapterData))] + public async Task AppCanHandleClientAbortingConnectionMidRequest(ListenOptions listenOptions) + { + var readTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appStartedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var mockKestrelTrace = new Mock(Logger) { CallBase = true }; + var testContext = new TestServiceContext() + { + Log = mockKestrelTrace.Object, + }; + + var scratchBuffer = new byte[4096]; + + using (var server = new TestServer(async context => + { + appStartedTcs.SetResult(null); + + try + { + await context.Request.Body.CopyToAsync(Stream.Null);; + } + catch (Exception ex) + { + readTcs.SetException(ex); + throw; + } + + readTcs.SetException(new Exception("This shouldn't be reached.")); + + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + $"Content-Length: {scratchBuffer.Length * 2}", + "", + ""); + + await appStartedTcs.Task.DefaultTimeout(); + + await connection.Stream.WriteAsync(scratchBuffer, 0, scratchBuffer.Length); + + connection.Reset(); + } + + await Assert.ThrowsAnyAsync(() => readTcs.Task).DefaultTimeout(); + } + + mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny()), Times.AtMostOnce()); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task RequestHeadersAreResetOnEachRequest(ListenOptions listenOptions) + { + var testContext = new TestServiceContext(LoggerFactory); + + IHeaderDictionary originalRequestHeaders = null; + var firstRequest = true; + + using (var server = new TestServer(httpContext => + { + var requestFeature = httpContext.Features.Get(); + + if (firstRequest) + { + originalRequestHeaders = requestFeature.Headers; + requestFeature.Headers = new HttpRequestHeaders(); + firstRequest = false; + } + else + { + Assert.Same(originalRequestHeaders, requestFeature.Headers); + } + + return Task.CompletedTask; + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task UpgradeRequestIsNotKeptAliveOrChunked(ListenOptions listenOptions) + { + const string message = "Hello World"; + + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async context => + { + var upgradeFeature = context.Features.Get(); + var duplexStream = await upgradeFeature.UpgradeAsync(); + + var buffer = new byte[message.Length]; + var read = 0; + while (read < message.Length) + { + read += await duplexStream.ReadAsync(buffer, read, buffer.Length - read).DefaultTimeout(); + } + + await duplexStream.WriteAsync(buffer, 0, read); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Connection: Upgrade", + "", + message); + await connection.ReceiveForcedEnd( + "HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {testContext.DateHeaderValue}", + "", + message); + } + } + } + + [Fact] + public async Task HeadersAndStreamsAreReusedAcrossRequests() + { + var testContext = new TestServiceContext(LoggerFactory); + var streamCount = 0; + var requestHeadersCount = 0; + var responseHeadersCount = 0; + var loopCount = 20; + Stream lastStream = null; + IHeaderDictionary lastRequestHeaders = null; + IHeaderDictionary lastResponseHeaders = null; + + using (var server = new TestServer(async context => + { + if (context.Request.Body != lastStream) + { + lastStream = context.Request.Body; + streamCount++; + } + if (context.Request.Headers != lastRequestHeaders) + { + lastRequestHeaders = context.Request.Headers; + requestHeadersCount++; + } + if (context.Response.Headers != lastResponseHeaders) + { + lastResponseHeaders = context.Response.Headers; + responseHeadersCount++; + } + + var ms = new MemoryStream(); + await context.Request.Body.CopyToAsync(ms); + var request = ms.ToArray(); + + context.Response.ContentLength = request.Length; + + await context.Response.Body.WriteAsync(request, 0, request.Length); + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + var requestData = + Enumerable.Repeat("GET / HTTP/1.1\r\nHost:\r\n", loopCount) + .Concat(new[] { "GET / HTTP/1.1\r\nHost:\r\nContent-Length: 7\r\nConnection: close\r\n\r\nGoodbye" }); + + var response = string.Join("\r\n", new string[] { + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + ""}); + + var lastResponse = string.Join("\r\n", new string[] + { + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 7", + "", + "Goodbye" + }); + + var responseData = + Enumerable.Repeat(response, loopCount) + .Concat(new[] { lastResponse }); + + await connection.Send(requestData.ToArray()); + + await connection.ReceiveEnd(responseData.ToArray()); + } + + Assert.Equal(1, streamCount); + Assert.Equal(1, requestHeadersCount); + Assert.Equal(1, responseHeadersCount); + } + } + + [Theory] + [MemberData(nameof(HostHeaderData))] + public async Task MatchesValidRequestTargetAndHostHeader(string request, string hostHeader) + { + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send($"{request} HTTP/1.1", + $"Host: {hostHeader}", + "", + ""); + + await connection.Receive("HTTP/1.1 200 OK"); + } + } + } + + [Fact] + public async Task ServerConsumesKeepAliveContentLengthRequest() + { + // The app doesn't read the request body, so it should be consumed by the server + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "hello"); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + + // If the server consumed the previous request properly, the + // next request should be successful + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "world"); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Fact] + public async Task ServerConsumesKeepAliveChunkedRequest() + { + // The app doesn't read the request body, so it should be consumed by the server + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "5", + "hello", + "5", + "world", + "0", + "Trailer: value", + "", + ""); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + + // If the server consumed the previous request properly, the + // next request should be successful + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "world"); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Fact] + public async Task NonKeepAliveRequestNotConsumedByAppCompletes() + { + // The app doesn't read the request body, so it should be consumed by the server + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.0", + "Host:", + "Content-Length: 5", + "", + "hello"); + + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Fact] + public async Task UpgradedRequestNotConsumedByAppCompletes() + { + // The app doesn't read the request body, so it should be consumed by the server + using (var server = new TestServer(async context => + { + var upgradeFeature = context.Features.Get(); + var duplexStream = await upgradeFeature.UpgradeAsync(); + + var response = Encoding.ASCII.GetBytes("goodbye"); + await duplexStream.WriteAsync(response, 0, response.Length); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "GET / HTTP/1.1", + "Host:", + "Connection: upgrade", + "", + "hello"); + + await connection.ReceiveForcedEnd( + "HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {server.Context.DateHeaderValue}", + "", + "goodbye"); + } + } + } + + [Fact] + public async Task DoesNotEnforceRequestBodyMinimumDataRateOnUpgradedRequest() + { + var appEvent = new TaskCompletionSource(); + var delayEvent = new TaskCompletionSource(); + var serviceContext = new TestServiceContext(LoggerFactory) + { + SystemClock = new SystemClock() + }; + + using (var server = new TestServer(async context => + { + context.Features.Get().MinDataRate = + new MinDataRate(bytesPerSecond: double.MaxValue, gracePeriod: Heartbeat.Interval + TimeSpan.FromTicks(1)); + + using (var stream = await context.Features.Get().UpgradeAsync()) + { + appEvent.SetResult(null); + + // Read once to go through one set of TryPauseTimingReads()/TryResumeTimingReads() calls + await stream.ReadAsync(new byte[1], 0, 1); + + await delayEvent.Task.DefaultTimeout(); + + // Read again to check that the connection is still alive + await stream.ReadAsync(new byte[1], 0, 1); + + // Send a response to distinguish from the timeout case where the 101 is still received, but without any content + var response = Encoding.ASCII.GetBytes("hello"); + await stream.WriteAsync(response, 0, response.Length); + } + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Connection: upgrade", + "", + "a"); + + await appEvent.Task.DefaultTimeout(); + + await Task.Delay(TimeSpan.FromSeconds(5)); + + delayEvent.SetResult(null); + + await connection.Send("b"); + + await connection.Receive( + "HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + ""); + await connection.ReceiveStartsWith( + $"Date: "); + await connection.ReceiveForcedEnd( + "", + "hello"); + } + } + } + + [Fact] + public async Task SynchronousReadsAllowedByDefault() + { + var firstRequest = true; + + using (var server = new TestServer(async context => + { + var bodyControlFeature = context.Features.Get(); + Assert.True(bodyControlFeature.AllowSynchronousIO); + + var buffer = new byte[6]; + var offset = 0; + + // The request body is 5 bytes long. The 6th byte (buffer[5]) is only used for writing the response body. + buffer[5] = (byte)(firstRequest ? '1' : '2'); + + if (firstRequest) + { + while (offset < 5) + { + offset += context.Request.Body.Read(buffer, offset, 5 - offset); + } + + firstRequest = false; + } + else + { + bodyControlFeature.AllowSynchronousIO = false; + + // Synchronous reads now throw. + var ioEx = Assert.Throws(() => context.Request.Body.Read(new byte[1], 0, 1)); + Assert.Equal(CoreStrings.SynchronousReadsDisallowed, ioEx.Message); + + var ioEx2 = Assert.Throws(() => context.Request.Body.CopyTo(Stream.Null)); + Assert.Equal(CoreStrings.SynchronousReadsDisallowed, ioEx2.Message); + + while (offset < 5) + { + offset += await context.Request.Body.ReadAsync(buffer, offset, 5 - offset); + } + } + + Assert.Equal(0, await context.Request.Body.ReadAsync(new byte[1], 0, 1)); + Assert.Equal("Hello", Encoding.ASCII.GetString(buffer, 0, 5)); + + context.Response.ContentLength = 6; + await context.Response.Body.WriteAsync(buffer, 0, 6); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "HelloPOST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "Hello"); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 6", + "", + "Hello1HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 6", + "", + "Hello2"); + } + } + } + + [Fact] + public async Task SynchronousReadsCanBeDisallowedGlobally() + { + var testContext = new TestServiceContext(LoggerFactory) + { + ServerOptions = { AllowSynchronousIO = false } + }; + + using (var server = new TestServer(async context => + { + var bodyControlFeature = context.Features.Get(); + Assert.False(bodyControlFeature.AllowSynchronousIO); + + // Synchronous reads now throw. + var ioEx = Assert.Throws(() => context.Request.Body.Read(new byte[1], 0, 1)); + Assert.Equal(CoreStrings.SynchronousReadsDisallowed, ioEx.Message); + + var ioEx2 = Assert.Throws(() => context.Request.Body.CopyTo(Stream.Null)); + Assert.Equal(CoreStrings.SynchronousReadsDisallowed, ioEx2.Message); + + var buffer = new byte[5]; + var offset = 0; + while (offset < 5) + { + offset += await context.Request.Body.ReadAsync(buffer, offset, 5 - offset); + } + + Assert.Equal(0, await context.Request.Body.ReadAsync(new byte[1], 0, 1)); + Assert.Equal("Hello", Encoding.ASCII.GetString(buffer)); + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "Hello"); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + private async Task TestRemoteIPAddress(string registerAddress, string requestAddress, string expectAddress) + { + var builder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls($"http://{registerAddress}:0") + .ConfigureServices(AddTestLogging) + .Configure(app => + { + app.Run(async context => + { + var connection = context.Connection; + await context.Response.WriteAsync(JsonConvert.SerializeObject(new + { + RemoteIPAddress = connection.RemoteIpAddress?.ToString(), + RemotePort = connection.RemotePort, + LocalIPAddress = connection.LocalIpAddress?.ToString(), + LocalPort = connection.LocalPort + })); + }); + }); + + using (var host = builder.Build()) + using (var client = new HttpClient()) + { + host.Start(); + + var response = await client.GetAsync($"http://{requestAddress}:{host.GetPort()}/"); + response.EnsureSuccessStatusCode(); + + var connectionFacts = await response.Content.ReadAsStringAsync(); + Assert.NotEmpty(connectionFacts); + + var facts = JsonConvert.DeserializeObject(connectionFacts); + Assert.Equal(expectAddress, facts["RemoteIPAddress"].Value()); + Assert.NotEmpty(facts["RemotePort"].Value()); + } + } + + public static TheoryData HostHeaderData => HttpParsingData.HostHeaderData; + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/ResponseTests.cs b/src/Servers/Kestrel/test/FunctionalTests/ResponseTests.cs new file mode 100644 index 0000000000..39b6725ee3 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/ResponseTests.cs @@ -0,0 +1,3363 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.AspNetCore.Testing.xunit; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; +using Microsoft.Extensions.Primitives; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class ResponseTests : TestApplicationErrorLoggerLoggedTest + { + public static TheoryData ConnectionAdapterData => new TheoryData + { + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)), + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = { new PassThroughConnectionAdapter() } + } + }; + + [Fact] + public async Task LargeDownload() + { + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls("http://127.0.0.1:0/") + .ConfigureServices(AddTestLogging) + .Configure(app => + { + app.Run(async context => + { + var bytes = new byte[1024]; + for (int i = 0; i < bytes.Length; i++) + { + bytes[i] = (byte)i; + } + + context.Response.ContentLength = bytes.Length * 1024; + + for (int i = 0; i < 1024; i++) + { + await context.Response.Body.WriteAsync(bytes, 0, bytes.Length); + } + }); + }); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var client = new HttpClient()) + { + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); + response.EnsureSuccessStatusCode(); + var responseBody = await response.Content.ReadAsStreamAsync(); + + // Read the full response body + var total = 0; + var bytes = new byte[1024]; + var count = await responseBody.ReadAsync(bytes, 0, bytes.Length); + while (count > 0) + { + for (int i = 0; i < count; i++) + { + Assert.Equal(total % 256, bytes[i]); + total++; + } + count = await responseBody.ReadAsync(bytes, 0, bytes.Length); + } + } + } + } + + [Theory, MemberData(nameof(NullHeaderData))] + public async Task IgnoreNullHeaderValues(string headerName, StringValues headerValue, string expectedValue) + { + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls("http://127.0.0.1:0/") + .ConfigureServices(AddTestLogging) + .Configure(app => + { + app.Run(async context => + { + context.Response.Headers.Add(headerName, headerValue); + + await context.Response.WriteAsync(""); + }); + }); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var client = new HttpClient()) + { + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); + response.EnsureSuccessStatusCode(); + + var headers = response.Headers; + + if (expectedValue == null) + { + Assert.False(headers.Contains(headerName)); + } + else + { + Assert.True(headers.Contains(headerName)); + Assert.Equal(headers.GetValues(headerName).Single(), expectedValue); + } + } + } + } + + [Fact] + public async Task OnCompleteCalledEvenWhenOnStartingNotCalled() + { + var onStartingCalled = false; + var onCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls("http://127.0.0.1:0/") + .ConfigureServices(AddTestLogging) + .Configure(app => + { + app.Run(context => + { + context.Response.OnStarting(() => Task.Run(() => onStartingCalled = true)); + context.Response.OnCompleted(() => Task.Run(() => + { + onCompletedTcs.SetResult(null); + })); + + // Prevent OnStarting call (see HttpProtocol.ProcessRequestsAsync()). + throw new Exception(); + }); + }); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var client = new HttpClient()) + { + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); + + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + Assert.False(onStartingCalled); + await onCompletedTcs.Task.DefaultTimeout(); + } + } + } + + [Fact] + public async Task OnStartingThrowsWhenSetAfterResponseHasAlreadyStarted() + { + InvalidOperationException ex = null; + + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls("http://127.0.0.1:0/") + .ConfigureServices(AddTestLogging) + .Configure(app => + { + app.Run(async context => + { + await context.Response.WriteAsync("hello, world"); + await context.Response.Body.FlushAsync(); + ex = Assert.Throws(() => context.Response.OnStarting(_ => Task.CompletedTask, null)); + }); + }); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var client = new HttpClient()) + { + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); + + // Despite the error, the response had already started + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.NotNull(ex); + } + } + } + + [Fact] + public Task ResponseStatusCodeSetBeforeHttpContextDisposeAppException() + { + return ResponseStatusCodeSetBeforeHttpContextDispose( + TestSink, + LoggerFactory, + context => + { + throw new Exception(); + }, + expectedClientStatusCode: HttpStatusCode.InternalServerError, + expectedServerStatusCode: HttpStatusCode.InternalServerError); + } + + [Fact] + public Task ResponseStatusCodeSetBeforeHttpContextDisposeRequestAborted() + { + return ResponseStatusCodeSetBeforeHttpContextDispose( + TestSink, + LoggerFactory, + context => + { + context.Abort(); + return Task.CompletedTask; + }, + expectedClientStatusCode: null, + expectedServerStatusCode: 0); + } + + [Fact] + public Task ResponseStatusCodeSetBeforeHttpContextDisposeRequestAbortedAppException() + { + return ResponseStatusCodeSetBeforeHttpContextDispose( + TestSink, + LoggerFactory, + context => + { + context.Abort(); + throw new Exception(); + }, + expectedClientStatusCode: null, + expectedServerStatusCode: 0); + } + + [Fact] + public Task ResponseStatusCodeSetBeforeHttpContextDisposedRequestMalformed() + { + return ResponseStatusCodeSetBeforeHttpContextDispose( + TestSink, + LoggerFactory, + context => + { + return Task.CompletedTask; + }, + expectedClientStatusCode: HttpStatusCode.OK, + expectedServerStatusCode: HttpStatusCode.OK, + sendMalformedRequest: true); + } + + [Fact] + public Task ResponseStatusCodeSetBeforeHttpContextDisposedRequestMalformedRead() + { + return ResponseStatusCodeSetBeforeHttpContextDispose( + TestSink, + LoggerFactory, + async context => + { + await context.Request.Body.ReadAsync(new byte[1], 0, 1); + }, + expectedClientStatusCode: null, + expectedServerStatusCode: HttpStatusCode.BadRequest, + sendMalformedRequest: true); + } + + [Fact] + public Task ResponseStatusCodeSetBeforeHttpContextDisposedRequestMalformedReadIgnored() + { + return ResponseStatusCodeSetBeforeHttpContextDispose( + TestSink, + LoggerFactory, + async context => + { + try + { + await context.Request.Body.ReadAsync(new byte[1], 0, 1); + } + catch (BadHttpRequestException) + { + } + }, + expectedClientStatusCode: HttpStatusCode.OK, + expectedServerStatusCode: HttpStatusCode.OK, + sendMalformedRequest: true); + } + + [Fact] + public async Task OnCompletedExceptionShouldNotPreventAResponse() + { + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls("http://127.0.0.1:0/") + .ConfigureServices(AddTestLogging) + .Configure(app => + { + app.Run(async context => + { + context.Response.OnCompleted(_ => throw new Exception(), null); + await context.Response.WriteAsync("hello, world"); + }); + }); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var client = new HttpClient()) + { + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + } + } + + [Fact] + public async Task OnCompletedShouldNotBlockAResponse() + { + var delayTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var hostBuilder = TransportSelector.GetWebHostBuilder() + .UseKestrel() + .UseUrls("http://127.0.0.1:0/") + .ConfigureServices(AddTestLogging) + .Configure(app => + { + app.Run(async context => + { + context.Response.OnCompleted(async () => + { + await delayTcs.Task; + }); + await context.Response.WriteAsync("hello, world"); + }); + }); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var client = new HttpClient()) + { + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + delayTcs.SetResult(null); + } + } + + [Fact] + public async Task InvalidChunkedEncodingInRequestShouldNotBlockOnCompleted() + { + var onCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(httpContext => + { + httpContext.Response.OnCompleted(() => Task.Run(() => + { + onCompletedTcs.SetResult(null); + })); + return Task.CompletedTask; + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "gg"); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + await onCompletedTcs.Task.DefaultTimeout(); + } + + private static async Task ResponseStatusCodeSetBeforeHttpContextDispose( + ITestSink testSink, + ILoggerFactory loggerFactory, + RequestDelegate handler, + HttpStatusCode? expectedClientStatusCode, + HttpStatusCode expectedServerStatusCode, + bool sendMalformedRequest = false) + { + var mockHttpContextFactory = new Mock(); + mockHttpContextFactory.Setup(f => f.Create(It.IsAny())) + .Returns(fc => new DefaultHttpContext(fc)); + + var disposedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + mockHttpContextFactory.Setup(f => f.Dispose(It.IsAny())) + .Callback(c => + { + disposedTcs.TrySetResult(c.Response.StatusCode); + }); + + using (var server = new TestServer(handler, new TestServiceContext(loggerFactory), + new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)), + services => services.AddSingleton(mockHttpContextFactory.Object))) + { + if (!sendMalformedRequest) + { + using (var client = new HttpClient()) + { + try + { + var response = await client.GetAsync($"http://127.0.0.1:{server.Port}/"); + Assert.Equal(expectedClientStatusCode, response.StatusCode); + } + catch + { + if (expectedClientStatusCode != null) + { + throw; + } + } + } + } + else + { + using (var connection = new TestConnection(server.Port)) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "gg"); + if (expectedClientStatusCode == HttpStatusCode.OK) + { + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + else + { + await connection.ReceiveForcedEnd( + "HTTP/1.1 400 Bad Request", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + var disposedStatusCode = await disposedTcs.Task.DefaultTimeout(); + Assert.Equal(expectedServerStatusCode, (HttpStatusCode)disposedStatusCode); + } + + if (sendMalformedRequest) + { + Assert.Contains(testSink.Writes, w => w.EventId.Id == 17 && w.LogLevel == LogLevel.Information && w.Exception is BadHttpRequestException + && ((BadHttpRequestException)w.Exception).StatusCode == StatusCodes.Status400BadRequest); + } + else + { + Assert.DoesNotContain(testSink.Writes, w => w.EventId.Id == 17 && w.LogLevel == LogLevel.Information && w.Exception is BadHttpRequestException + && ((BadHttpRequestException)w.Exception).StatusCode == StatusCodes.Status400BadRequest); + } + } + + // https://github.com/aspnet/KestrelHttpServer/pull/1111/files#r80584475 explains the reason for this test. + [Fact] + public async Task NoErrorResponseSentWhenAppSwallowsBadRequestException() + { + BadHttpRequestException readException = null; + + using (var server = new TestServer(async httpContext => + { + readException = await Assert.ThrowsAsync( + async () => await httpContext.Request.Body.ReadAsync(new byte[1], 0, 1)); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "gg"); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.NotNull(readException); + + Assert.Contains(TestSink.Writes, w => w.EventId.Id == 17 && w.LogLevel == LogLevel.Information && w.Exception is BadHttpRequestException + && ((BadHttpRequestException)w.Exception).StatusCode == StatusCodes.Status400BadRequest); + } + + [Fact] + public async Task TransferEncodingChunkedSetOnUnknownLengthHttp11Response() + { + using (var server = new TestServer(async httpContext => + { + await httpContext.Response.WriteAsync("hello, "); + await httpContext.Response.WriteAsync("world"); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "7", + "hello, ", + "5", + "world", + "0", + "", + ""); + } + } + } + + [Theory] + [InlineData(StatusCodes.Status204NoContent)] + [InlineData(StatusCodes.Status205ResetContent)] + [InlineData(StatusCodes.Status304NotModified)] + public async Task TransferEncodingChunkedNotSetOnNonBodyResponse(int statusCode) + { + using (var server = new TestServer(httpContext => + { + httpContext.Response.StatusCode = statusCode; + return Task.CompletedTask; + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + $"HTTP/1.1 {Encoding.ASCII.GetString(ReasonPhrases.ToStatusBytes(statusCode))}", + $"Date: {server.Context.DateHeaderValue}", + "", + ""); + } + } + } + + [Fact] + public async Task TransferEncodingNotSetOnHeadResponse() + { + using (var server = new TestServer(httpContext => + { + return Task.CompletedTask; + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "HEAD / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "", + ""); + } + } + } + + [Fact] + public async Task ResponseBodyNotWrittenOnHeadResponseAndLoggedOnlyOnce() + { + const string response = "hello, world"; + + var logTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var mockKestrelTrace = new Mock(); + mockKestrelTrace + .Setup(trace => trace.ConnectionHeadResponseBodyWrite(It.IsAny(), response.Length)) + .Callback((connectionId, count) => logTcs.SetResult(null)); + + using (var server = new TestServer(async httpContext => + { + await httpContext.Response.WriteAsync(response); + await httpContext.Response.Body.FlushAsync(); + }, new TestServiceContext(LoggerFactory, mockKestrelTrace.Object))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "HEAD / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "", + ""); + + // Wait for message to be logged before disposing the socket. + // Disposing the socket will abort the connection and HttpProtocol._requestAborted + // might be 1 by the time ProduceEnd() gets called and the message is logged. + await logTcs.Task.DefaultTimeout(); + } + } + + mockKestrelTrace.Verify(kestrelTrace => + kestrelTrace.ConnectionHeadResponseBodyWrite(It.IsAny(), response.Length), Times.Once); + } + + [Fact] + public async Task ThrowsAndClosesConnectionWhenAppWritesMoreThanContentLengthWrite() + { + var serviceContext = new TestServiceContext(LoggerFactory) + { + ServerOptions = { AllowSynchronousIO = true } + }; + + using (var server = new TestServer(httpContext => + { + httpContext.Response.ContentLength = 11; + httpContext.Response.Body.Write(Encoding.ASCII.GetBytes("hello,"), 0, 6); + httpContext.Response.Body.Write(Encoding.ASCII.GetBytes(" world"), 0, 6); + return Task.CompletedTask; + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 11", + "", + "hello,"); + + await connection.WaitForConnectionClose().DefaultTimeout(); + } + } + + var logMessage = Assert.Single(TestApplicationErrorLogger.Messages, message => message.LogLevel == LogLevel.Error); + + Assert.Equal( + $"Response Content-Length mismatch: too many bytes written (12 of 11).", + logMessage.Exception.Message); + + } + + [Fact] + public async Task ThrowsAndClosesConnectionWhenAppWritesMoreThanContentLengthWriteAsync() + { + var serviceContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 11; + await httpContext.Response.WriteAsync("hello,"); + await httpContext.Response.WriteAsync(" world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveForcedEnd( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 11", + "", + "hello,"); + } + } + + var logMessage = Assert.Single(TestApplicationErrorLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too many bytes written (12 of 11).", + logMessage.Exception.Message); + } + + [Fact] + public async Task InternalServerErrorAndConnectionClosedOnWriteWithMoreThanContentLengthAndResponseNotStarted() + { + var serviceContext = new TestServiceContext(LoggerFactory) + { + ServerOptions = { AllowSynchronousIO = true } + }; + + using (var server = new TestServer(httpContext => + { + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = 5; + httpContext.Response.Body.Write(response, 0, response.Length); + return Task.CompletedTask; + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveForcedEnd( + $"HTTP/1.1 500 Internal Server Error", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + var logMessage = Assert.Single(TestApplicationErrorLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too many bytes written (12 of 5).", + logMessage.Exception.Message); + } + + [Fact] + public async Task InternalServerErrorAndConnectionClosedOnWriteAsyncWithMoreThanContentLengthAndResponseNotStarted() + { + var serviceContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(httpContext => + { + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = 5; + return httpContext.Response.Body.WriteAsync(response, 0, response.Length); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveForcedEnd( + $"HTTP/1.1 500 Internal Server Error", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + var logMessage = Assert.Single(TestApplicationErrorLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too many bytes written (12 of 5).", + logMessage.Exception.Message); + } + + [Fact] + public async Task WhenAppWritesLessThanContentLengthErrorLogged() + { + var logTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var mockTrace = new Mock(); + mockTrace + .Setup(trace => trace.ApplicationError(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((connectionId, requestId, ex) => + { + logTcs.SetResult(null); + }); + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 13; + await httpContext.Response.WriteAsync("hello, world"); + }, new TestServiceContext(LoggerFactory, mockTrace.Object))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + // Don't use ReceiveEnd here, otherwise the FIN might + // abort the request before the server checks the + // response content length, in which case the check + // will be skipped. + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 13", + "", + "hello, world"); + + // Wait for error message to be logged. + await logTcs.Task.DefaultTimeout(); + + // The server should close the connection in this situation. + await connection.WaitForConnectionClose().DefaultTimeout(); + } + } + + mockTrace.Verify(trace => + trace.ApplicationError( + It.IsAny(), + It.IsAny(), + It.Is(ex => + ex.Message.Equals($"Response Content-Length mismatch: too few bytes written (12 of 13).", StringComparison.Ordinal)))); + } + + [Fact] + public async Task WhenAppWritesLessThanContentLengthButRequestIsAbortedErrorNotLogged() + { + var requestAborted = new SemaphoreSlim(0); + var mockTrace = new Mock(); + + using (var server = new TestServer(async httpContext => + { + httpContext.RequestAborted.Register(() => + { + requestAborted.Release(2); + }); + + httpContext.Response.ContentLength = 12; + await httpContext.Response.WriteAsync("hello,"); + + // Wait until the request is aborted so we know HttpProtocol will skip the response content length check. + Assert.True(await requestAborted.WaitAsync(TestConstants.DefaultTimeout)); + }, new TestServiceContext(LoggerFactory, mockTrace.Object))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 12", + "", + "hello,"); + } + + // Verify the request was really aborted. A timeout in + // the app would cause a server error and skip the content length + // check altogether, making the test pass for the wrong reason. + // Await before disposing the server to prevent races between the + // abort triggered by the connection RST and the abort called when + // disposing the server. + Assert.True(await requestAborted.WaitAsync(TestConstants.DefaultTimeout)); + } + + // With the server disposed we know all connections were drained and all messages were logged. + mockTrace.Verify(trace => trace.ApplicationError(It.IsAny(), It.IsAny(), It.IsAny()), Times.Never); + } + + [Fact] + public async Task WhenAppSetsContentLengthButDoesNotWriteBody500ResponseSentAndConnectionDoesNotClose() + { + var serviceContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(httpContext => + { + httpContext.Response.ContentLength = 5; + return Task.CompletedTask; + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 500 Internal Server Error", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + "HTTP/1.1 500 Internal Server Error", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + var error = TestApplicationErrorLogger.Messages.Where(message => message.LogLevel == LogLevel.Error); + Assert.Equal(2, error.Count()); + Assert.All(error, message => message.Message.Equals("Response Content-Length mismatch: too few bytes written (0 of 5).")); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task WhenAppSetsContentLengthToZeroAndDoesNotWriteNoErrorIsThrown(bool flushResponse) + { + var serviceContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 0; + + if (flushResponse) + { + await httpContext.Response.Body.FlushAsync(); + } + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.Empty(TestApplicationErrorLogger.Messages.Where(message => message.LogLevel == LogLevel.Error)); + } + + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. + [Fact] + public async Task WhenAppSetsTransferEncodingAndContentLengthWritingLessIsNotAnError() + { + var serviceContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.Headers["Transfer-Encoding"] = "chunked"; + httpContext.Response.ContentLength = 13; + await httpContext.Response.WriteAsync("hello, world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 13", + "Transfer-Encoding: chunked", + "", + "hello, world"); + } + } + + Assert.Empty(TestApplicationErrorLogger.Messages.Where(message => message.LogLevel == LogLevel.Error)); + } + + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. + [Fact] + public async Task WhenAppSetsTransferEncodingAndContentLengthWritingMoreIsNotAnError() + { + var serviceContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.Headers["Transfer-Encoding"] = "chunked"; + httpContext.Response.ContentLength = 11; + await httpContext.Response.WriteAsync("hello, world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 11", + "Transfer-Encoding: chunked", + "", + "hello, world"); + } + } + + Assert.Empty(TestApplicationErrorLogger.Messages.Where(message => message.LogLevel == LogLevel.Error)); + } + + [Fact] + public async Task HeadResponseCanContainContentLengthHeader() + { + using (var server = new TestServer(httpContext => + { + httpContext.Response.ContentLength = 42; + return Task.CompletedTask; + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "HEAD / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 42", + "", + ""); + } + } + } + + [Fact] + public async Task HeadResponseBodyNotWrittenWithAsyncWrite() + { + var flushed = new SemaphoreSlim(0, 1); + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 12; + await httpContext.Response.WriteAsync("hello, world"); + await flushed.WaitAsync(); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "HEAD / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 12", + "", + ""); + + flushed.Release(); + } + } + } + + [Fact] + public async Task HeadResponseBodyNotWrittenWithSyncWrite() + { + var flushed = new SemaphoreSlim(0, 1); + var serviceContext = new TestServiceContext(LoggerFactory) { ServerOptions = { AllowSynchronousIO = true } }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 12; + httpContext.Response.Body.Write(Encoding.ASCII.GetBytes("hello, world"), 0, 12); + await flushed.WaitAsync(); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "HEAD / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 12", + "", + ""); + + flushed.Release(); + } + } + } + + [Fact] + public async Task ZeroLengthWritesFlushHeaders() + { + var flushed = new SemaphoreSlim(0, 1); + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 12; + await httpContext.Response.WriteAsync(""); + await flushed.WaitAsync(); + await httpContext.Response.WriteAsync("hello, world"); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 12", + "", + ""); + + flushed.Release(); + + await connection.ReceiveEnd("hello, world"); + } + } + } + + [Fact] + public async Task WriteAfterConnectionCloseNoops() + { + var connectionClosed = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var requestStarted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appCompleted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(async httpContext => + { + try + { + requestStarted.SetResult(null); + await connectionClosed.Task.DefaultTimeout(); + httpContext.Response.ContentLength = 12; + await httpContext.Response.WriteAsync("hello, world"); + appCompleted.TrySetResult(null); + } + catch (Exception ex) + { + appCompleted.TrySetException(ex); + } + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + await requestStarted.Task.DefaultTimeout(); + connection.Shutdown(SocketShutdown.Send); + await connection.WaitForConnectionClose().DefaultTimeout(); + } + + connectionClosed.SetResult(null); + + await appCompleted.Task.DefaultTimeout(); + } + } + + [Fact] + public async Task AppCanWriteOwnBadRequestResponse() + { + var expectedResponse = string.Empty; + var responseWritten = new SemaphoreSlim(0); + + using (var server = new TestServer(async httpContext => + { + try + { + await httpContext.Request.Body.ReadAsync(new byte[1], 0, 1); + } + catch (BadHttpRequestException ex) + { + expectedResponse = ex.Message; + httpContext.Response.StatusCode = StatusCodes.Status400BadRequest; + httpContext.Response.ContentLength = ex.Message.Length; + await httpContext.Response.WriteAsync(ex.Message); + responseWritten.Release(); + } + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "gg"); + await responseWritten.WaitAsync().DefaultTimeout(); + await connection.ReceiveEnd( + "HTTP/1.1 400 Bad Request", + $"Date: {server.Context.DateHeaderValue}", + $"Content-Length: {expectedResponse.Length}", + "", + expectedResponse); + } + } + } + + [Theory] + [InlineData("gzip")] + [InlineData("chunked, gzip")] + public async Task ConnectionClosedWhenChunkedIsNotFinalTransferCoding(string responseTransferEncoding) + { + using (var server = new TestServer(async httpContext => + { + httpContext.Response.Headers["Transfer-Encoding"] = responseTransferEncoding; + await httpContext.Response.WriteAsync("hello, world"); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: {responseTransferEncoding}", + "", + "hello, world"); + } + + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.0", + "Connection: keep-alive", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: {responseTransferEncoding}", + "", + "hello, world"); + } + } + } + + [Theory] + [InlineData("gzip")] + [InlineData("chunked, gzip")] + public async Task ConnectionClosedWhenChunkedIsNotFinalTransferCodingEvenIfConnectionKeepAliveSetInResponse(string responseTransferEncoding) + { + using (var server = new TestServer(async httpContext => + { + httpContext.Response.Headers["Connection"] = "keep-alive"; + httpContext.Response.Headers["Transfer-Encoding"] = responseTransferEncoding; + await httpContext.Response.WriteAsync("hello, world"); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: keep-alive", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: {responseTransferEncoding}", + "", + "hello, world"); + } + + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.0", + "Connection: keep-alive", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: keep-alive", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: {responseTransferEncoding}", + "", + "hello, world"); + } + } + } + + [Theory] + [InlineData("chunked")] + [InlineData("gzip, chunked")] + public async Task ConnectionKeptAliveWhenChunkedIsFinalTransferCoding(string responseTransferEncoding) + { + using (var server = new TestServer(async httpContext => + { + httpContext.Response.Headers["Transfer-Encoding"] = responseTransferEncoding; + + // App would have to chunk manually, but here we don't care + await httpContext.Response.WriteAsync("hello, world"); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: {responseTransferEncoding}", + "", + "hello, world"); + + // Make sure connection was kept open + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: {responseTransferEncoding}", + "", + "hello, world"); + } + } + } + + [Fact] + public async Task FirstWriteVerifiedAfterOnStarting() + { + var serviceContext = new TestServiceContext(LoggerFactory) { ServerOptions = { AllowSynchronousIO = true } }; + + using (var server = new TestServer(httpContext => + { + httpContext.Response.OnStarting(() => + { + // Change response to chunked + httpContext.Response.ContentLength = null; + return Task.CompletedTask; + }); + + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = response.Length - 1; + + // If OnStarting is not run before verifying writes, an error response will be sent. + httpContext.Response.Body.Write(response, 0, response.Length); + return Task.CompletedTask; + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + } + } + + [Fact] + public async Task SubsequentWriteVerifiedAfterOnStarting() + { + var serviceContext = new TestServiceContext(LoggerFactory) { ServerOptions = { AllowSynchronousIO = true } }; + + using (var server = new TestServer(httpContext => + { + httpContext.Response.OnStarting(() => + { + // Change response to chunked + httpContext.Response.ContentLength = null; + return Task.CompletedTask; + }); + + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = response.Length - 1; + + // If OnStarting is not run before verifying writes, an error response will be sent. + httpContext.Response.Body.Write(response, 0, response.Length / 2); + httpContext.Response.Body.Write(response, response.Length / 2, response.Length - response.Length / 2); + return Task.CompletedTask; + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "6", + "hello,", + "6", + " world", + "0", + "", + ""); + } + } + } + + [Fact] + public async Task FirstWriteAsyncVerifiedAfterOnStarting() + { + using (var server = new TestServer(httpContext => + { + httpContext.Response.OnStarting(() => + { + // Change response to chunked + httpContext.Response.ContentLength = null; + return Task.CompletedTask; + }); + + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = response.Length - 1; + + // If OnStarting is not run before verifying writes, an error response will be sent. + return httpContext.Response.Body.WriteAsync(response, 0, response.Length); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + } + } + + [Fact] + public async Task SubsequentWriteAsyncVerifiedAfterOnStarting() + { + using (var server = new TestServer(async httpContext => + { + httpContext.Response.OnStarting(() => + { + // Change response to chunked + httpContext.Response.ContentLength = null; + return Task.CompletedTask; + }); + + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = response.Length - 1; + + // If OnStarting is not run before verifying writes, an error response will be sent. + await httpContext.Response.Body.WriteAsync(response, 0, response.Length / 2); + await httpContext.Response.Body.WriteAsync(response, response.Length / 2, response.Length - response.Length / 2); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "6", + "hello,", + "6", + " world", + "0", + "", + ""); + } + } + } + + [Fact] + public async Task WhenResponseAlreadyStartedResponseEndedBeforeConsumingRequestBody() + { + using (var server = new TestServer(async httpContext => + { + await httpContext.Response.WriteAsync("hello, world"); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 1", + "", + ""); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "c", + "hello, world", + ""); + + // If the expected behavior is regressed, this will hang because the + // server will try to consume the request body before flushing the chunked + // terminator. + await connection.Receive( + "0", + "", + ""); + } + } + } + + [Fact] + public async Task WhenResponseNotStartedResponseEndedBeforeConsumingRequestBody() + { + using (var server = new TestServer(httpContext => Task.CompletedTask, + new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "gg"); + + // This will receive a success response because the server flushed the response + // before reading the malformed chunk header in the request, but then it will close + // the connection. + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.Contains(TestApplicationErrorLogger.Messages, w => w.EventId.Id == 17 && w.LogLevel == LogLevel.Information && w.Exception is BadHttpRequestException + && ((BadHttpRequestException)w.Exception).StatusCode == StatusCodes.Status400BadRequest); + } + + [Fact] + public async Task RequestDrainingFor100ContinueDoesNotBlockResponse() + { + var foundMessage = false; + using (var server = new TestServer(httpContext => + { + return httpContext.Request.Body.ReadAsync(new byte[1], 0, 1); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "Expect: 100-continue", + "", + ""); + + await connection.Receive( + "HTTP/1.1 100 Continue", + "", + ""); + + // Let the app finish + await connection.Send( + "1", + "a", + ""); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + + // This will be consumed by Http1Connection when it attempts to + // consume the request body and will cause an error. + await connection.Send( + "gg"); + + // Wait for the server to drain the request body and log an error. + // Time out after 10 seconds + for (int i = 0; i < 10 && !foundMessage; i++) + { + while (TestApplicationErrorLogger.Messages.TryDequeue(out var message)) + { + if (message.EventId.Id == 17 && message.LogLevel == LogLevel.Information && message.Exception is BadHttpRequestException + && ((BadHttpRequestException)message.Exception).StatusCode == StatusCodes.Status400BadRequest) + { + foundMessage = true; + break; + } + } + + if (!foundMessage) + { + await Task.Delay(TimeSpan.FromSeconds(1)); + } + } + + await connection.ReceiveEnd(); + } + } + + Assert.True(foundMessage, "Expected log not found"); + } + + [Fact] + public async Task Sending100ContinueDoesNotPreventAutomatic400Responses() + { + using (var server = new TestServer(httpContext => + { + return httpContext.Request.Body.ReadAsync(new byte[1], 0, 1); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "Expect: 100-continue", + "", + ""); + + await connection.Receive( + "HTTP/1.1 100 Continue", + "", + ""); + + // Send an invalid chunk prefix to cause an error. + await connection.Send( + "gg"); + + // If 100 Continue sets HttpProtocol.HasResponseStarted to true, + // a success response will be produced before the server sees the + // bad chunk header above, making this test fail. + await connection.ReceiveForcedEnd( + "HTTP/1.1 400 Bad Request", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.Contains(TestApplicationErrorLogger.Messages, w => w.EventId.Id == 17 && w.LogLevel == LogLevel.Information && w.Exception is BadHttpRequestException + && ((BadHttpRequestException)w.Exception).StatusCode == StatusCodes.Status400BadRequest); + } + + [Fact] + public async Task Sending100ContinueAndResponseSendsChunkTerminatorBeforeConsumingRequestBody() + { + using (var server = new TestServer(async httpContext => + { + await httpContext.Request.Body.ReadAsync(new byte[1], 0, 1); + await httpContext.Response.WriteAsync("hello, world"); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 2", + "Expect: 100-continue", + "", + ""); + + await connection.Receive( + "HTTP/1.1 100 Continue", + "", + ""); + + await connection.Send( + "a"); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "c", + "hello, world", + ""); + + // If the expected behavior is regressed, this will hang because the + // server will try to consume the request body before flushing the chunked + // terminator. + await connection.Receive( + "0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task Http11ResponseSentToHttp10Request(ListenOptions listenOptions) + { + var serviceContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EchoApp, serviceContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.0", + "Content-Length: 11", + "", + "Hello World"); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {serviceContext.DateHeaderValue}", + "", + "Hello World"); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ZeroContentLengthSetAutomaticallyAfterNoWrites(ListenOptions listenOptions) + { + var testContext= new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EmptyApp, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + "GET / HTTP/1.0", + "Connection: keep-alive", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + "HTTP/1.1 200 OK", + "Connection: keep-alive", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ZeroContentLengthSetAutomaticallyForNonKeepAliveRequests(ListenOptions listenOptions) + { + var testContext= new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + Assert.Equal(0, await httpContext.Request.Body.ReadAsync(new byte[1], 0, 1).DefaultTimeout()); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Connection: close", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.0", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ZeroContentLengthNotSetAutomaticallyForHeadRequests(ListenOptions listenOptions) + { + var testContext= new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(TestApp.EmptyApp, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "HEAD / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ZeroContentLengthNotSetAutomaticallyForCertainStatusCodes(ListenOptions listenOptions) + { + var testContext= new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var request = httpContext.Request; + var response = httpContext.Response; + + using (var reader = new StreamReader(request.Body, Encoding.ASCII)) + { + var statusString = await reader.ReadLineAsync(); + response.StatusCode = int.Parse(statusString); + } + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 3", + "", + "204POST / HTTP/1.1", + "Host:", + "Content-Length: 3", + "", + "205POST / HTTP/1.1", + "Host:", + "Content-Length: 3", + "", + "304POST / HTTP/1.1", + "Host:", + "Content-Length: 3", + "", + "200"); + await connection.ReceiveEnd( + "HTTP/1.1 204 No Content", + $"Date: {testContext.DateHeaderValue}", + "", + "HTTP/1.1 205 Reset Content", + $"Date: {testContext.DateHeaderValue}", + "", + "HTTP/1.1 304 Not Modified", + $"Date: {testContext.DateHeaderValue}", + "", + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ConnectionClosedAfter101Response(ListenOptions listenOptions) + { + var testContext= new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var request = httpContext.Request; + var stream = await httpContext.Features.Get().UpgradeAsync(); + var response = Encoding.ASCII.GetBytes("hello, world"); + await stream.WriteAsync(response, 0, response.Length); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Connection: Upgrade", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {testContext.DateHeaderValue}", + "", + "hello, world"); + } + + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.0", + "Connection: keep-alive, Upgrade", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {testContext.DateHeaderValue}", + "", + "hello, world"); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ThrowingResultsIn500Response(ListenOptions listenOptions) + { + var testContext= new TestServiceContext(LoggerFactory); + + bool onStartingCalled = false; + + using (var server = new TestServer(httpContext => + { + var response = httpContext.Response; + response.OnStarting(_ => + { + onStartingCalled = true; + return Task.CompletedTask; + }, null); + + // Anything added to the ResponseHeaders dictionary is ignored + response.Headers["Content-Length"] = "11"; + throw new Exception(); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + "GET / HTTP/1.1", + "Host:", + "Connection: close", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 500 Internal Server Error", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + "HTTP/1.1 500 Internal Server Error", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.False(onStartingCalled); + Assert.Equal(2, TestApplicationErrorLogger.Messages.Where(message => message.LogLevel == LogLevel.Error).Count()); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ThrowingInOnStartingResultsInFailedWritesAnd500Response(ListenOptions listenOptions) + { + var callback1Called = false; + var callback2CallCount = 0; + + var testContext= new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var onStartingException = new Exception(); + + var response = httpContext.Response; + response.OnStarting(_ => + { + callback1Called = true; + throw onStartingException; + }, null); + response.OnStarting(_ => + { + callback2CallCount++; + throw onStartingException; + }, null); + + var writeException = await Assert.ThrowsAsync(async () => await response.Body.FlushAsync()); + Assert.Same(onStartingException, writeException.InnerException); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 500 Internal Server Error", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + "HTTP/1.1 500 Internal Server Error", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + // The first registered OnStarting callback should have been called, + // since they are called LIFO order and the other one failed. + Assert.False(callback1Called); + Assert.Equal(2, callback2CallCount); + Assert.Equal(2, TestApplicationErrorLogger.Messages.Where(message => message.LogLevel == LogLevel.Error).Count()); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ThrowingInOnCompletedIsLogged(ListenOptions listenOptions) + { + var testContext= new TestServiceContext(LoggerFactory); + + var onCompletedCalled1 = false; + var onCompletedCalled2 = false; + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + response.OnCompleted(_ => + { + onCompletedCalled1 = true; + throw new Exception(); + }, null); + response.OnCompleted(_ => + { + onCompletedCalled2 = true; + throw new Exception(); + }, null); + + response.Headers["Content-Length"] = new[] { "11" }; + + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + } + } + + // All OnCompleted callbacks should be called even if they throw. + Assert.Equal(2, TestApplicationErrorLogger.Messages.Where(message => message.LogLevel == LogLevel.Error).Count()); + Assert.True(onCompletedCalled1); + Assert.True(onCompletedCalled2); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ThrowingAfterWritingKillsConnection(ListenOptions listenOptions) + { + var testContext= new TestServiceContext(LoggerFactory); + + bool onStartingCalled = false; + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + response.OnStarting(_ => + { + onStartingCalled = true; + return Task.FromResult(null); + }, null); + + response.Headers["Content-Length"] = new[] { "11" }; + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11); + throw new Exception(); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + } + } + + Assert.True(onStartingCalled); + Assert.Single(TestApplicationErrorLogger.Messages, message => message.LogLevel == LogLevel.Error); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ThrowingAfterPartialWriteKillsConnection(ListenOptions listenOptions) + { + var testContext= new TestServiceContext(LoggerFactory); + + bool onStartingCalled = false; + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + response.OnStarting(_ => + { + onStartingCalled = true; + return Task.FromResult(null); + }, null); + + response.Headers["Content-Length"] = new[] { "11" }; + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello"), 0, 5); + throw new Exception(); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello"); + } + } + + Assert.True(onStartingCalled); + Assert.Single(TestApplicationErrorLogger.Messages, message => message.LogLevel == LogLevel.Error); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ThrowsOnWriteWithRequestAbortedTokenAfterRequestIsAborted(ListenOptions listenOptions) + { + // This should match _maxBytesPreCompleted in SocketOutput + var maxBytesPreCompleted = 65536; + + // Ensure string is long enough to disable write-behind buffering + var largeString = new string('a', maxBytesPreCompleted + 1); + + var writeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var requestAbortedWh = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var requestStartWh = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(async httpContext => + { + requestStartWh.SetResult(null); + + var response = httpContext.Response; + var request = httpContext.Request; + var lifetime = httpContext.Features.Get(); + + lifetime.RequestAborted.Register(() => requestAbortedWh.SetResult(null)); + await requestAbortedWh.Task.DefaultTimeout(); + + try + { + await response.WriteAsync(largeString, lifetime.RequestAborted); + } + catch (Exception ex) + { + writeTcs.SetException(ex); + throw; + } + + writeTcs.SetException(new Exception("This shouldn't be reached.")); + }, new TestServiceContext(LoggerFactory), listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 0", + "", + ""); + + await requestStartWh.Task.DefaultTimeout(); + } + + // Write failed - can throw TaskCanceledException or OperationCanceledException, + // depending on how far the canceled write goes. + await Assert.ThrowsAnyAsync(async () => await writeTcs.Task).DefaultTimeout(); + + // RequestAborted tripped + await requestAbortedWh.Task.DefaultTimeout(); + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task WritingToConnectionAfterUnobservedCloseTriggersRequestAbortedToken(ListenOptions listenOptions) + { + const int connectionPausedEventId = 4; + const int maxRequestBufferSize = 2048; + + var requestAborted = false; + var readCallbackUnwired = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientClosedConnection = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var writeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var mockKestrelTrace = new Mock(Logger) { CallBase = true }; + var mockLogger = new Mock(); + mockLogger + .Setup(logger => logger.IsEnabled(It.IsAny())) + .Returns(true); + mockLogger + .Setup(logger => logger.Log(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny>())) + .Callback>((logLevel, eventId, state, exception, formatter) => + { + if (eventId.Id == connectionPausedEventId) + { + readCallbackUnwired.TrySetResult(null); + } + + Logger.Log(logLevel, eventId, state, exception, formatter); + }); + + var mockLoggerFactory = new Mock(); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsAny())) + .Returns(Logger); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv", + "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"))) + .Returns(mockLogger.Object); + + var testContext = new TestServiceContext(mockLoggerFactory.Object) + { + Log = mockKestrelTrace.Object, + ServerOptions = + { + Limits = + { + MaxRequestBufferSize = maxRequestBufferSize, + MaxRequestLineSize = maxRequestBufferSize, + MaxRequestHeadersTotalSize = maxRequestBufferSize, + } + } + }; + + var scratchBuffer = new byte[maxRequestBufferSize * 8]; + + using (var server = new TestServer(async context => + { + context.RequestAborted.Register(() => + { + requestAborted = true; + }); + + await clientClosedConnection.Task; + + try + { + for (var i = 0; i < 1000; i++) + { + await context.Response.Body.WriteAsync(scratchBuffer, 0, scratchBuffer.Length, context.RequestAborted); + await Task.Delay(10); + } + } + catch (Exception ex) + { + writeTcs.SetException(ex); + throw; + } + + writeTcs.SetException(new Exception("This shouldn't be reached.")); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + $"Content-Length: {scratchBuffer.Length}", + "", + ""); + + var ignore = connection.Stream.WriteAsync(scratchBuffer, 0, scratchBuffer.Length); + + // Wait until the read callback is no longer hooked up so that the connection disconnect isn't observed. + await readCallbackUnwired.Task.DefaultTimeout(); + } + + clientClosedConnection.SetResult(null); + + await Assert.ThrowsAnyAsync(() => writeTcs.Task).DefaultTimeout(); + } + + mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny()), Times.AtMostOnce()); + Assert.True(requestAborted); + } + + [ConditionalTheory] + [OSSkipCondition(OperatingSystems.MacOSX, SkipReason = "macOS EPIPE vs. EPROTOTYPE bug https://github.com/aspnet/KestrelHttpServer/issues/2885")] + [MemberData(nameof(ConnectionAdapterData))] + public async Task AppCanHandleClientAbortingConnectionMidResponse(ListenOptions listenOptions) + { + const int connectionResetEventId = 19; + const int connectionFinEventId = 6; + //const int connectionStopEventId = 2; + + const int responseBodySegmentSize = 65536; + const int responseBodySegmentCount = 100; + + var appCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var requestAborted = false; + + var scratchBuffer = new byte[responseBodySegmentSize]; + + using (var server = new TestServer(async context => + { + context.RequestAborted.Register(() => + { + requestAborted = true; + }); + + for (var i = 0; i < responseBodySegmentCount; i++) + { + await context.Response.Body.WriteAsync(scratchBuffer, 0, scratchBuffer.Length); + await Task.Delay(10); + } + + appCompletedTcs.SetResult(null); + }, new TestServiceContext(LoggerFactory), listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + // Read just part of the response and close the connection. + // https://github.com/aspnet/KestrelHttpServer/issues/2554 + await connection.Stream.ReadAsync(scratchBuffer, 0, scratchBuffer.Length); + + connection.Reset(); + } + + await appCompletedTcs.Task.DefaultTimeout(); + + // After the app is done with the write loop, the connection reset should be logged. + // On Linux and macOS, the connection close is still sometimes observed as a FIN despite the LingerState. + var presShutdownTransportLogs = TestSink.Writes.Where( + w => w.LoggerName == "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv" || + w.LoggerName == "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"); + var connectionResetLogs = presShutdownTransportLogs.Where( + w => w.EventId == connectionResetEventId || + (!TestPlatformHelper.IsWindows && w.EventId == connectionFinEventId)); + + Assert.NotEmpty(connectionResetLogs); + } + + // TODO: Figure out what the following assertion is flaky. The server shouldn't shutdown before all + // the connections are closed, yet sometimes the connection stop log isn't observed here. + //var coreLogs = TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.Server.Kestrel"); + //Assert.Single(coreLogs.Where(w => w.EventId == connectionStopEventId)); + + Assert.True(requestAborted, "RequestAborted token didn't fire."); + + var transportLogs = TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv" || + w.LoggerName == "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"); + Assert.Empty(transportLogs.Where(w => w.LogLevel > LogLevel.Debug)); + } + + [ConditionalTheory] + [OSSkipCondition(OperatingSystems.MacOSX, SkipReason = "macOS EPIPE vs. EPROTOTYPE bug https://github.com/aspnet/KestrelHttpServer/issues/2885")] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ClientAbortingConnectionImmediatelyIsNotLoggedHigherThanDebug(ListenOptions listenOptions) + { + // Attempt multiple connections to be extra sure the resets are consistently logged appropriately. + const int numConnections = 10; + + // There's not guarantee that the app even gets invoked in this test. The connection reset can be observed + // as early as accept. + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory), listenOptions)) + { + for (var i = 0; i < numConnections; i++) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + connection.Reset(); + } + } + } + + var transportLogs = TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv" || + w.LoggerName == "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"); + + Assert.Empty(transportLogs.Where(w => w.LogLevel > LogLevel.Debug)); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task NoErrorsLoggedWhenServerEndsConnectionBeforeClient(ListenOptions listenOptions) + { + var testContext= new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + response.Headers["Content-Length"] = new[] { "11" }; + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.0", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 11", + "", + "Hello World"); + } + } + + Assert.Empty(TestApplicationErrorLogger.Messages.Where(message => message.LogLevel == LogLevel.Error)); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task NoResponseSentWhenConnectionIsClosedByServerBeforeClientFinishesSendingRequest(ListenOptions listenOptions) + { + var testContext= new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(httpContext => + { + httpContext.Abort(); + return Task.CompletedTask; + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.0", + "Content-Length: 1", + "", + ""); + await connection.ReceiveForcedEnd(); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ResponseHeadersAreResetOnEachRequest(ListenOptions listenOptions) + { + var testContext= new TestServiceContext(LoggerFactory); + + IHeaderDictionary originalResponseHeaders = null; + var firstRequest = true; + + using (var server = new TestServer(httpContext => + { + var responseFeature = httpContext.Features.Get(); + + if (firstRequest) + { + originalResponseHeaders = responseFeature.Headers; + responseFeature.Headers = new HttpResponseHeaders(); + firstRequest = false; + } + else + { + Assert.Same(originalResponseHeaders, responseFeature.Headers); + } + + return Task.CompletedTask; + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task OnStartingCallbacksAreCalledInLastInFirstOutOrder(ListenOptions listenOptions) + { + const string response = "hello, world"; + + var testContext= new TestServiceContext(LoggerFactory); + + var callOrder = new Stack(); + var onStartingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(async context => + { + context.Response.OnStarting(_ => + { + callOrder.Push(1); + onStartingTcs.SetResult(null); + return Task.CompletedTask; + }, null); + context.Response.OnStarting(_ => + { + callOrder.Push(2); + return Task.CompletedTask; + }, null); + + context.Response.ContentLength = response.Length; + await context.Response.WriteAsync(response); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + $"Content-Length: {response.Length}", + "", + "hello, world"); + + // Wait for all callbacks to be called. + await onStartingTcs.Task.DefaultTimeout(); + } + } + + Assert.Equal(1, callOrder.Pop()); + Assert.Equal(2, callOrder.Pop()); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task OnCompletedCallbacksAreCalledInLastInFirstOutOrder(ListenOptions listenOptions) + { + const string response = "hello, world"; + + var testContext= new TestServiceContext(LoggerFactory); + + var callOrder = new Stack(); + var onCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(async context => + { + context.Response.OnCompleted(_ => + { + callOrder.Push(1); + onCompletedTcs.SetResult(null); + return Task.CompletedTask; + }, null); + context.Response.OnCompleted(_ => + { + callOrder.Push(2); + return Task.CompletedTask; + }, null); + + context.Response.ContentLength = response.Length; + await context.Response.WriteAsync(response); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + $"Content-Length: {response.Length}", + "", + "hello, world"); + + // Wait for all callbacks to be called. + await onCompletedTcs.Task.DefaultTimeout(); + } + } + + Assert.Equal(1, callOrder.Pop()); + Assert.Equal(2, callOrder.Pop()); + } + + + [Fact] + public async Task SynchronousWritesAllowedByDefault() + { + var firstRequest = true; + + using (var server = new TestServer(async context => + { + var bodyControlFeature = context.Features.Get(); + Assert.True(bodyControlFeature.AllowSynchronousIO); + + context.Response.ContentLength = 6; + + if (firstRequest) + { + context.Response.Body.Write(Encoding.ASCII.GetBytes("Hello1"), 0, 6); + firstRequest = false; + } + else + { + bodyControlFeature.AllowSynchronousIO = false; + + // Synchronous writes now throw. + var ioEx = Assert.Throws(() => context.Response.Body.Write(Encoding.ASCII.GetBytes("What!?"), 0, 6)); + Assert.Equal(CoreStrings.SynchronousWritesDisallowed, ioEx.Message); + + await context.Response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello2"), 0, 6); + } + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGet(); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 6", + "", + "Hello1"); + + await connection.SendEmptyGet(); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 6", + "", + "Hello2"); + } + } + } + + [Fact] + public async Task SynchronousWritesCanBeDisallowedGlobally() + { + var testContext = new TestServiceContext(LoggerFactory) + { + ServerOptions = { AllowSynchronousIO = false } + }; + + using (var server = new TestServer(context => + { + var bodyControlFeature = context.Features.Get(); + Assert.False(bodyControlFeature.AllowSynchronousIO); + + context.Response.ContentLength = 6; + + // Synchronous writes now throw. + var ioEx = Assert.Throws(() => context.Response.Body.Write(Encoding.ASCII.GetBytes("What!?"), 0, 6)); + Assert.Equal(CoreStrings.SynchronousWritesDisallowed, ioEx.Message); + + return context.Response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello!"), 0, 6); + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 6", + "", + "Hello!"); + } + } + } + + [Fact] + public async Task ConnectionClosedWhenResponseDoesNotSatisfyMinimumDataRate() + { + using (StartLog(out var loggerFactory, "ConnClosedWhenRespDoesNotSatisfyMin")) + { + var logger = loggerFactory.CreateLogger($"{ typeof(ResponseTests).FullName}.{ nameof(ConnectionClosedWhenResponseDoesNotSatisfyMinimumDataRate)}"); + const int chunkSize = 1024; + const int chunks = 256 * 1024; + var responseSize = chunks * chunkSize; + var chunkData = new byte[chunkSize]; + + var responseRateTimeoutMessageLogged = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var connectionStopMessageLogged = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var requestAborted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appFuncCompleted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var mockKestrelTrace = new Mock(); + mockKestrelTrace + .Setup(trace => trace.ResponseMininumDataRateNotSatisfied(It.IsAny(), It.IsAny())) + .Callback(() => responseRateTimeoutMessageLogged.SetResult(null)); + mockKestrelTrace + .Setup(trace => trace.ConnectionStop(It.IsAny())) + .Callback(() => connectionStopMessageLogged.SetResult(null)); + + var testContext = new TestServiceContext + { + LoggerFactory = loggerFactory, + Log = mockKestrelTrace.Object, + SystemClock = new SystemClock(), + ServerOptions = + { + Limits = + { + MinResponseDataRate = new MinDataRate(bytesPerSecond: 1024 * 1024, gracePeriod: TimeSpan.FromSeconds(2)) + } + } + }; + + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + listenOptions.ConnectionAdapters.Add(new LoggingConnectionAdapter(loggerFactory.CreateLogger())); + + var appLogger = loggerFactory.CreateLogger("App"); + async Task App(HttpContext context) + { + appLogger.LogInformation("Request received"); + context.RequestAborted.Register(() => requestAborted.SetResult(null)); + + context.Response.ContentLength = responseSize; + + try + { + for (var i = 0; i < chunks; i++) + { + await context.Response.Body.WriteAsync(chunkData, 0, chunkData.Length, context.RequestAborted); + appLogger.LogInformation("Wrote chunk of {chunkSize} bytes", chunkSize); + } + } + catch (OperationCanceledException) + { + appFuncCompleted.SetResult(null); + throw; + } + } + + using (var server = new TestServer(App, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + logger.LogInformation("Sending request"); + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + logger.LogInformation("Sent request"); + + var sw = Stopwatch.StartNew(); + logger.LogInformation("Waiting for connection to abort."); + + await requestAborted.Task.DefaultTimeout(); + await responseRateTimeoutMessageLogged.Task.DefaultTimeout(); + await connectionStopMessageLogged.Task.DefaultTimeout(); + await appFuncCompleted.Task.DefaultTimeout(); + await AssertStreamAborted(connection.Reader.BaseStream, chunkSize * chunks); + + sw.Stop(); + logger.LogInformation("Connection was aborted after {totalMilliseconds}ms.", sw.ElapsedMilliseconds); + } + } + } + } + + [Fact] + public async Task HttpsConnectionClosedWhenResponseDoesNotSatisfyMinimumDataRate() + { + const int chunkSize = 1024; + const int chunks = 256 * 1024; + var chunkData = new byte[chunkSize]; + + var certificate = new X509Certificate2(TestResources.TestCertificatePath, "testPassword"); + + var responseRateTimeoutMessageLogged = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var connectionStopMessageLogged = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var aborted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var appFuncCompleted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var mockKestrelTrace = new Mock(); + mockKestrelTrace + .Setup(trace => trace.ResponseMininumDataRateNotSatisfied(It.IsAny(), It.IsAny())) + .Callback(() => responseRateTimeoutMessageLogged.SetResult(null)); + mockKestrelTrace + .Setup(trace => trace.ConnectionStop(It.IsAny())) + .Callback(() => connectionStopMessageLogged.SetResult(null)); + + var testContext = new TestServiceContext(LoggerFactory, mockKestrelTrace.Object) + { + SystemClock = new SystemClock(), + ServerOptions = + { + Limits = + { + MinResponseDataRate = new MinDataRate(bytesPerSecond: 1024 * 1024, gracePeriod: TimeSpan.FromSeconds(2)) + } + } + }; + + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) + { + ConnectionAdapters = + { + new HttpsConnectionAdapter(new HttpsConnectionAdapterOptions { ServerCertificate = certificate }) + } + }; + + using (var server = new TestServer(async context => + { + context.RequestAborted.Register(() => + { + aborted.SetResult(null); + }); + + context.Response.ContentLength = chunks * chunkSize; + + try + { + for (var i = 0; i < chunks; i++) + { + await context.Response.Body.WriteAsync(chunkData, 0, chunkData.Length, context.RequestAborted); + } + } + catch (OperationCanceledException) + { + appFuncCompleted.SetResult(null); + throw; + } + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + using (var sslStream = new SslStream(connection.Reader.BaseStream, false, (sender, cert, chain, errors) => true, null)) + { + await sslStream.AuthenticateAsClientAsync("localhost", new X509CertificateCollection(), SslProtocols.Tls12 | SslProtocols.Tls11, false); + + var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"); + await sslStream.WriteAsync(request, 0, request.Length); + + await aborted.Task.DefaultTimeout(); + await responseRateTimeoutMessageLogged.Task.DefaultTimeout(); + await connectionStopMessageLogged.Task.DefaultTimeout(); + await appFuncCompleted.Task.DefaultTimeout(); + + // Temporary workaround for a deadlock when reading from an aborted client SslStream on Mac and Linux. + if (TestPlatformHelper.IsWindows) + { + await AssertStreamAborted(sslStream, chunkSize * chunks); + } + else + { + await AssertStreamAborted(connection.Reader.BaseStream, chunkSize * chunks); + } + } + } + } + } + + [Fact] + public async Task ConnectionClosedWhenBothRequestAndResponseExperienceBackPressure() + { + const int bufferSize = 65536; + const int bufferCount = 100; + var responseSize = bufferCount * bufferSize; + var buffer = new byte[bufferSize]; + + var responseRateTimeoutMessageLogged = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var connectionStopMessageLogged = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var requestAborted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var copyToAsyncCts = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var mockKestrelTrace = new Mock(); + mockKestrelTrace + .Setup(trace => trace.ResponseMininumDataRateNotSatisfied(It.IsAny(), It.IsAny())) + .Callback(() => responseRateTimeoutMessageLogged.SetResult(null)); + mockKestrelTrace + .Setup(trace => trace.ConnectionStop(It.IsAny())) + .Callback(() => connectionStopMessageLogged.SetResult(null)); + + var testContext = new TestServiceContext + { + LoggerFactory = LoggerFactory, + Log = mockKestrelTrace.Object, + SystemClock = new SystemClock(), + ServerOptions = + { + Limits = + { + MinResponseDataRate = new MinDataRate(bytesPerSecond: 1024 * 1024, gracePeriod: TimeSpan.FromSeconds(2)), + MaxRequestBodySize = responseSize + } + } + }; + + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + + async Task App(HttpContext context) + { + context.RequestAborted.Register(() => + { + requestAborted.SetResult(null); + }); + + try + { + await context.Request.Body.CopyToAsync(context.Response.Body); + } + catch (Exception ex) + { + copyToAsyncCts.SetException(ex); + throw; + } + + copyToAsyncCts.SetException(new Exception("This shouldn't be reached.")); + } + + using (var server = new TestServer(App, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + // Close the connection with the last request so AssertStreamCompleted actually completes. + await connection.Send( + "POST / HTTP/1.1", + "Host:", + $"Content-Length: {responseSize}", + "", + ""); + + var sendTask = Task.Run(async () => + { + for (var i = 0; i < bufferCount; i++) + { + await connection.Stream.WriteAsync(buffer, 0, buffer.Length); + await Task.Delay(10); + } + }); + + await requestAborted.Task.DefaultTimeout(); + await responseRateTimeoutMessageLogged.Task.DefaultTimeout(); + await connectionStopMessageLogged.Task.DefaultTimeout(); + + // Expect OperationCanceledException instead of IOException because the server initiated the abort due to a response rate timeout. + await Assert.ThrowsAnyAsync(() => copyToAsyncCts.Task).DefaultTimeout(); + await AssertStreamAborted(connection.Stream, responseSize); + } + } + } + + [Fact] + public async Task ConnectionNotClosedWhenClientSatisfiesMinimumDataRateGivenLargeResponseChunks() + { + var chunkSize = 64 * 128 * 1024; + var chunkCount = 4; + var chunkData = new byte[chunkSize]; + + var requestAborted = false; + var appFuncCompleted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var mockKestrelTrace = new Mock(); + + var testContext = new TestServiceContext + { + Log = mockKestrelTrace.Object, + SystemClock = new SystemClock(), + ServerOptions = + { + Limits = + { + MinResponseDataRate = new MinDataRate(bytesPerSecond: 240, gracePeriod: TimeSpan.FromSeconds(2)) + } + } + }; + + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + + async Task App(HttpContext context) + { + context.RequestAborted.Register(() => + { + requestAborted = true; + }); + + for (var i = 0; i < chunkCount; i++) + { + await context.Response.Body.WriteAsync(chunkData, 0, chunkData.Length, context.RequestAborted); + } + + appFuncCompleted.SetResult(null); + } + + using (var server = new TestServer(App, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + // Close the connection with the last request so AssertStreamCompleted actually completes. + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Connection: close", + "", + ""); + + var minTotalOutputSize = chunkCount * chunkSize; + + // Make sure consuming a single chunk exceeds the 2 second timeout. + var targetBytesPerSecond = chunkSize / 4; + await AssertStreamCompleted(connection.Reader.BaseStream, minTotalOutputSize, targetBytesPerSecond); + await appFuncCompleted.Task.DefaultTimeout(); + + mockKestrelTrace.Verify(t => t.ResponseMininumDataRateNotSatisfied(It.IsAny(), It.IsAny()), Times.Never()); + mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny()), Times.AtMostOnce()); + Assert.False(requestAborted); + } + } + } + + [Fact] + public async Task ConnectionNotClosedWhenClientSatisfiesMinimumDataRateGivenLargeResponseHeaders() + { + var headerSize = 1024 * 1024; // 1 MB for each header value + var headerCount = 64; // 64 MB of headers per response + var requestCount = 4; // Minimum of 256 MB of total response headers + var headerValue = new string('a', headerSize); + var headerStringValues = new StringValues(Enumerable.Repeat(headerValue, headerCount).ToArray()); + + var requestAborted = false; + var mockKestrelTrace = new Mock(); + + var testContext = new TestServiceContext + { + Log = mockKestrelTrace.Object, + SystemClock = new SystemClock(), + ServerOptions = + { + Limits = + { + MinResponseDataRate = new MinDataRate(bytesPerSecond: 240, gracePeriod: TimeSpan.FromSeconds(2)) + } + } + }; + + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); + + async Task App(HttpContext context) + { + context.RequestAborted.Register(() => + { + requestAborted = true; + }); + + context.Response.Headers[$"X-Custom-Header"] = headerStringValues; + context.Response.ContentLength = 0; + + await context.Response.Body.FlushAsync(); + } + + using (var server = new TestServer(App, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + for (var i = 0; i < requestCount - 1; i++) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + } + + // Close the connection with the last request so AssertStreamCompleted actually completes. + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Connection: close", + "", + ""); + + var responseSize = headerSize * headerCount; + var minTotalOutputSize = requestCount * responseSize; + + // Make sure consuming a single set of response headers exceeds the 2 second timeout. + var targetBytesPerSecond = responseSize / 4; + await AssertStreamCompleted(connection.Reader.BaseStream, minTotalOutputSize, targetBytesPerSecond); + + mockKestrelTrace.Verify(t => t.ResponseMininumDataRateNotSatisfied(It.IsAny(), It.IsAny()), Times.Never()); + mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny()), Times.AtMostOnce()); + Assert.False(requestAborted); + } + } + } + + [Fact] + public async Task NonZeroContentLengthFor304StatusCodeIsAllowed() + { + using (var server = new TestServer(httpContext => + { + var response = httpContext.Response; + response.StatusCode = StatusCodes.Status304NotModified; + response.ContentLength = 42; + + return Task.CompletedTask; + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 304 Not Modified", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 42", + "", + ""); + } + } + } + + private async Task AssertStreamAborted(Stream stream, int totalBytes) + { + var receiveBuffer = new byte[64 * 1024]; + var totalReceived = 0; + + try + { + while (totalReceived < totalBytes) + { + var bytes = await stream.ReadAsync(receiveBuffer, 0, receiveBuffer.Length).DefaultTimeout(); + + if (bytes == 0) + { + break; + } + + totalReceived += bytes; + } + } + catch (IOException) + { + // This is expected given an abort. + } + + Assert.True(totalReceived < totalBytes, $"{nameof(AssertStreamAborted)} Stream completed successfully."); + } + + private async Task AssertStreamCompleted(Stream stream, long minimumBytes, int targetBytesPerSecond) + { + var receiveBuffer = new byte[64 * 1024]; + var received = 0; + var totalReceived = 0; + var startTime = DateTimeOffset.UtcNow; + + do + { + received = await stream.ReadAsync(receiveBuffer, 0, receiveBuffer.Length); + totalReceived += received; + + var expectedTimeElapsed = TimeSpan.FromSeconds(totalReceived / targetBytesPerSecond); + var timeElapsed = DateTimeOffset.UtcNow - startTime; + if (timeElapsed < expectedTimeElapsed) + { + await Task.Delay(expectedTimeElapsed - timeElapsed); + } + } while (received > 0); + + Assert.True(totalReceived >= minimumBytes, $"{nameof(AssertStreamCompleted)} Stream aborted prematurely."); + } + + public static TheoryData NullHeaderData + { + get + { + var dataset = new TheoryData(); + + // Unknown headers + dataset.Add("NullString", (string)null, null); + dataset.Add("EmptyString", "", ""); + dataset.Add("NullStringArray", new string[] { null }, null); + dataset.Add("EmptyStringArray", new string[] { "" }, ""); + dataset.Add("MixedStringArray", new string[] { null, "" }, ""); + // Known headers + dataset.Add("Location", (string)null, null); + dataset.Add("Location", "", ""); + dataset.Add("Location", new string[] { null }, null); + dataset.Add("Location", new string[] { "" }, ""); + dataset.Add("Location", new string[] { null, "" }, ""); + + return dataset; + } + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/HostNameIsReachableAttribute.cs b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/HostNameIsReachableAttribute.cs new file mode 100644 index 0000000000..5a83bd7c3f --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/HostNameIsReachableAttribute.cs @@ -0,0 +1,88 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Net.Sockets; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Testing.xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] + public class HostNameIsReachableAttribute : Attribute, ITestCondition + { + private string _hostname; + private string _error; + private bool? _isMet; + + public bool IsMet + { + get + { + return _isMet ?? (_isMet = HostNameIsReachable().GetAwaiter().GetResult()).Value; + } + } + + public string SkipReason => _hostname != null + ? $"Test cannot run when network is unreachable. Socket exception: '{_error}'" + : "Could not determine hostname for current test machine"; + + private async Task HostNameIsReachable() + { + try + { + _hostname = Dns.GetHostName(); + + // if the network is unreachable on macOS, throws with SocketError.NetworkUnreachable + // if the network device is not configured, throws with SocketError.HostNotFound + // if the network is reachable, throws with SocketError.ConnectionRefused or succeeds + var timeoutTask = Task.Delay(1000); + if (await Task.WhenAny(ConnectToHost(_hostname, 80), timeoutTask) == timeoutTask) + { + _error = "Attempt to establish a connection took over a second without success or failure."; + return false; + } + } + catch (SocketException ex) when ( + ex.SocketErrorCode == SocketError.NetworkUnreachable + || ex.SocketErrorCode == SocketError.HostNotFound) + { + _error = ex.Message; + return false; + } + catch + { + // Swallow other errors. Allows the test to throw the failures instead + } + + return true; + } + + public static async Task ConnectToHost(string hostName, int port) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var socketArgs = new SocketAsyncEventArgs(); + socketArgs.RemoteEndPoint = new DnsEndPoint(hostName, port); + socketArgs.Completed += (s, e) => tcs.TrySetResult(e.ConnectSocket); + + // Must use static ConnectAsync(), since instance Connect() does not support DNS names on OSX/Linux. + if (Socket.ConnectAsync(SocketType.Stream, ProtocolType.Tcp, socketArgs)) + { + await tcs.Task.ConfigureAwait(false); + } + + var socket = socketArgs.ConnectSocket; + + if (socket == null) + { + throw new SocketException((int)socketArgs.SocketError); + } + else + { + return socket; + } + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/IPv6ScopeIdPresentConditionAttribute.cs b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/IPv6ScopeIdPresentConditionAttribute.cs new file mode 100644 index 0000000000..f1705b9c7b --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/IPv6ScopeIdPresentConditionAttribute.cs @@ -0,0 +1,35 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Net.NetworkInformation; +using System.Net.Sockets; +using Microsoft.AspNetCore.Testing.xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class IPv6ScopeIdPresentConditionAttribute : Attribute, ITestCondition + { + private static readonly Lazy _ipv6ScopeIdPresent = new Lazy(IPv6ScopeIdAddressPresent); + + public bool IsMet => _ipv6ScopeIdPresent.Value; + + public string SkipReason => "No IPv6 addresses with scope IDs were found on the host."; + + private static bool IPv6ScopeIdAddressPresent() + { + try + { + return NetworkInterface.GetAllNetworkInterfaces() + .Where(iface => iface.OperationalStatus == OperationalStatus.Up) + .SelectMany(iface => iface.GetIPProperties().UnicastAddresses) + .Any(addrInfo => addrInfo.Address.AddressFamily == AddressFamily.InterNetworkV6 && addrInfo.Address.ScopeId != 0); + } + catch (SocketException) + { + return false; + } + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/IPv6SupportedConditionAttribute.cs b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/IPv6SupportedConditionAttribute.cs new file mode 100644 index 0000000000..815a271825 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/IPv6SupportedConditionAttribute.cs @@ -0,0 +1,36 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Net.Sockets; +using Microsoft.AspNetCore.Testing.xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] + public class IPv6SupportedConditionAttribute : Attribute, ITestCondition + { + private static readonly Lazy _ipv6Supported = new Lazy(CanBindToIPv6Address); + + public bool IsMet => _ipv6Supported.Value; + + public string SkipReason => "IPv6 not supported on the host."; + + private static bool CanBindToIPv6Address() + { + try + { + using (var socket = new Socket(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Bind(new IPEndPoint(IPAddress.IPv6Loopback, 0)); + return true; + } + } + catch (SocketException) + { + return false; + } + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/IWebHostPortExtensions.cs b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/IWebHostPortExtensions.cs new file mode 100644 index 0000000000..2f3ae47c24 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/IWebHostPortExtensions.cs @@ -0,0 +1,30 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNetCore.Hosting.Server.Features; + +namespace Microsoft.AspNetCore.Hosting +{ + public static class IWebHostPortExtensions + { + public static int GetPort(this IWebHost host) + { + return host.GetPorts().First(); + } + + public static IEnumerable GetPorts(this IWebHost host) + { + return host.GetUris() + .Select(u => u.Port); + } + + public static IEnumerable GetUris(this IWebHost host) + { + return host.ServerFeatures.Get().Addresses + .Select(a => new Uri(a)); + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/TestApplicationErrorLoggerLoggedTest.cs b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/TestApplicationErrorLoggerLoggedTest.cs new file mode 100644 index 0000000000..1dea888411 --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/TestApplicationErrorLoggerLoggedTest.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.Extensions.Logging.Testing; + +namespace Microsoft.AspNetCore.Testing +{ + public class TestApplicationErrorLoggerLoggedTest : LoggedTest + { + public TestApplicationErrorLogger TestApplicationErrorLogger { get; private set; } + + public override void AdditionalSetup() + { + TestApplicationErrorLogger = new TestApplicationErrorLogger(); + LoggerFactory.AddProvider(new KestrelTestLoggerProvider(TestApplicationErrorLogger)); + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/TestServer.cs b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/TestServer.cs new file mode 100644 index 0000000000..7a54c418fa --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/TestServer.cs @@ -0,0 +1,128 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Reflection; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + /// + /// Summary description for TestServer + /// + public class TestServer : IDisposable, IStartup + { + private IWebHost _host; + private ListenOptions _listenOptions; + private readonly RequestDelegate _app; + + public TestServer(RequestDelegate app) + : this(app, new TestServiceContext()) + { + } + + public TestServer(RequestDelegate app, TestServiceContext context) + : this(app, context, new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0))) + { + } + + public TestServer(RequestDelegate app, TestServiceContext context, ListenOptions listenOptions) + : this(app, context, listenOptions, _ => { }) + { + } + + public TestServer(RequestDelegate app, TestServiceContext context, ListenOptions listenOptions, Action configureServices) + { + _app = app; + _listenOptions = listenOptions; + Context = context; + + _host = TransportSelector.GetWebHostBuilder() + .UseKestrel(o => + { + o.ListenOptions.Add(_listenOptions); + }) + .ConfigureServices(services => + { + services.AddSingleton(this); + services.AddSingleton(context.LoggerFactory); + services.AddSingleton(sp => + { + // Manually configure options on the TestServiceContext. + // We're doing this so we can use the same instance that was passed in + var configureOptions = sp.GetServices>(); + foreach (var c in configureOptions) + { + c.Configure(context.ServerOptions); + } + return new KestrelServer(sp.GetRequiredService(), context); + }); + RemoveDevCert(services); + configureServices(services); + }) + .UseSetting(WebHostDefaults.ApplicationKey, typeof(TestServer).GetTypeInfo().Assembly.FullName) + .Build(); + + _host.Start(); + } + + public static void RemoveDevCert(IServiceCollection services) + { + // KestrelServerOptionsSetup would scan all system certificates on every test server creation + // making test runs very slow + foreach (var descriptor in services.ToArray()) + { + if (descriptor.ImplementationType == typeof(KestrelServerOptionsSetup)) + { + services.Remove(descriptor); + } + } + } + + public IPEndPoint EndPoint => _listenOptions.IPEndPoint; + public int Port => _listenOptions.IPEndPoint.Port; + public AddressFamily AddressFamily => _listenOptions.IPEndPoint.AddressFamily; + + public TestServiceContext Context { get; } + + void IStartup.Configure(IApplicationBuilder app) + { + app.Run(_app); + } + + IServiceProvider IStartup.ConfigureServices(IServiceCollection services) + { + // Unfortunately, this needs to be replaced in IStartup.ConfigureServices + services.AddSingleton(); + return services.BuildServiceProvider(); + } + + public TestConnection CreateConnection() + { + return new TestConnection(Port, AddressFamily); + } + + public Task StopAsync() + { + return _host.StopAsync(); + } + + public void Dispose() + { + _host.Dispose(); + } + } +} diff --git a/src/Servers/Kestrel/test/FunctionalTests/UpgradeTests.cs b/src/Servers/Kestrel/test/FunctionalTests/UpgradeTests.cs new file mode 100644 index 0000000000..2b4027d37c --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/UpgradeTests.cs @@ -0,0 +1,338 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Tests; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class UpgradeTests : LoggedTest + { + [Fact] + public async Task ResponseThrowsAfterUpgrade() + { + var upgrade = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using (var server = new TestServer(async context => + { + var feature = context.Features.Get(); + var stream = await feature.UpgradeAsync(); + + var ex = await Assert.ThrowsAsync(() => context.Response.Body.WriteAsync(new byte[1], 0, 1)); + Assert.Equal(CoreStrings.ResponseStreamWasUpgraded, ex.Message); + + using (var writer = new StreamWriter(stream)) + { + await writer.WriteLineAsync("New protocol data"); + await writer.FlushAsync(); + } + + upgrade.TrySetResult(true); + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGetWithUpgrade(); + await connection.Receive("HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {server.Context.DateHeaderValue}", + "", + ""); + + await connection.Receive("New protocol data"); + await upgrade.Task.DefaultTimeout(); + } + } + } + + [Fact] + public async Task RequestBodyAlwaysEmptyAfterUpgrade() + { + const string send = "Custom protocol send"; + const string recv = "Custom protocol recv"; + + var upgrade = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using (var server = new TestServer(async context => + { + try + { + var feature = context.Features.Get(); + var stream = await feature.UpgradeAsync(); + + var buffer = new byte[128]; + var read = await context.Request.Body.ReadAsync(buffer, 0, 128).DefaultTimeout(); + Assert.Equal(0, read); + + using (var reader = new StreamReader(stream)) + using (var writer = new StreamWriter(stream)) + { + var line = await reader.ReadLineAsync(); + Assert.Equal(send, line); + await writer.WriteLineAsync(recv); + await writer.FlushAsync(); + } + + upgrade.TrySetResult(true); + } + catch (Exception ex) + { + upgrade.SetException(ex); + throw; + } + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGetWithUpgrade(); + + await connection.Receive("HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {server.Context.DateHeaderValue}", + "", + ""); + + await connection.Send(send + "\r\n"); + await connection.Receive(recv); + + await upgrade.Task.DefaultTimeout(); + } + } + } + + [Fact] + public async Task UpgradeCannotBeCalledMultipleTimes() + { + var upgradeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using (var server = new TestServer(async context => + { + var feature = context.Features.Get(); + await feature.UpgradeAsync(); + + try + { + await feature.UpgradeAsync(); + } + catch (Exception e) + { + upgradeTcs.TrySetException(e); + throw; + } + + while (!context.RequestAborted.IsCancellationRequested) + { + await Task.Delay(100); + } + }, new TestServiceContext(LoggerFactory))) + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGetWithUpgrade(); + await connection.Receive("HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {server.Context.DateHeaderValue}", + "", + ""); + await connection.WaitForConnectionClose().DefaultTimeout(); + } + + var ex = await Assert.ThrowsAsync(async () => await upgradeTcs.Task.DefaultTimeout()); + Assert.Equal(CoreStrings.UpgradeCannotBeCalledMultipleTimes, ex.Message); + } + + [Fact] + public async Task RejectsRequestWithContentLengthAndUpgrade() + { + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory))) + using (var connection = server.CreateConnection()) + { + await connection.Send("POST / HTTP/1.1", + "Host:", + "Content-Length: 1", + "Connection: Upgrade", + "", + ""); + + await connection.ReceiveForcedEnd( + "HTTP/1.1 400 Bad Request", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + [Fact] + public async Task AcceptsRequestWithNoContentLengthAndUpgrade() + { + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send("POST / HTTP/1.1", + "Host:", + "Content-Length: 0", + "Connection: Upgrade, keep-alive", + "", + ""); + await connection.Receive("HTTP/1.1 200 OK"); + } + + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGetWithUpgrade(); + await connection.Receive("HTTP/1.1 200 OK"); + } + } + } + + [Fact] + public async Task RejectsRequestWithChunkedEncodingAndUpgrade() + { + using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory))) + using (var connection = server.CreateConnection()) + { + await connection.Send("POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "Connection: Upgrade", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.1 400 Bad Request", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + [Fact] + public async Task ThrowsWhenUpgradingNonUpgradableRequest() + { + var upgradeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using (var server = new TestServer(async context => + { + var feature = context.Features.Get(); + Assert.False(feature.IsUpgradableRequest); + try + { + var stream = await feature.UpgradeAsync(); + } + catch (Exception e) + { + upgradeTcs.TrySetException(e); + } + finally + { + upgradeTcs.TrySetResult(false); + } + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGet(); + await connection.Receive("HTTP/1.1 200 OK"); + } + } + + var ex = await Assert.ThrowsAsync(async () => await upgradeTcs.Task).DefaultTimeout(); + Assert.Equal(CoreStrings.CannotUpgradeNonUpgradableRequest, ex.Message); + } + + [Fact] + public async Task RejectsUpgradeWhenLimitReached() + { + const int limit = 10; + var upgradeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var serviceContext = new TestServiceContext(LoggerFactory); + serviceContext.ConnectionManager = new HttpConnectionManager(serviceContext.Log, ResourceCounter.Quota(limit)); + + using (var server = new TestServer(async context => + { + var feature = context.Features.Get(); + if (feature.IsUpgradableRequest) + { + try + { + var stream = await feature.UpgradeAsync(); + while (!context.RequestAborted.IsCancellationRequested) + { + await Task.Delay(100); + } + } + catch (InvalidOperationException ex) + { + upgradeTcs.TrySetException(ex); + } + } + }, serviceContext)) + { + using (var disposables = new DisposableStack()) + { + for (var i = 0; i < limit; i++) + { + var connection = server.CreateConnection(); + disposables.Push(connection); + + await connection.SendEmptyGetWithUpgradeAndKeepAlive(); + await connection.Receive("HTTP/1.1 101"); + } + + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGetWithUpgradeAndKeepAlive(); + await connection.Receive("HTTP/1.1 200"); + } + } + } + + var exception = await Assert.ThrowsAsync(async () => await upgradeTcs.Task.TimeoutAfter(TimeSpan.FromSeconds(60))); + Assert.Equal(CoreStrings.UpgradedConnectionLimitReached, exception.Message); + } + + [Fact] + public async Task DoesNotThrowOnFin() + { + var appCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(async context => + { + var feature = context.Features.Get(); + var duplexStream = await feature.UpgradeAsync(); + + try + { + await duplexStream.CopyToAsync(Stream.Null); + appCompletedTcs.SetResult(null); + } + catch (Exception ex) + { + appCompletedTcs.SetException(ex); + throw; + } + + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGetWithUpgrade(); + await connection.Receive("HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {server.Context.DateHeaderValue}", + "", + ""); + } + + await appCompletedTcs.Task.DefaultTimeout(); + } + } + } +} diff --git a/src/Servers/Kestrel/test/Libuv.FunctionalTests/Libuv.FunctionalTests.csproj b/src/Servers/Kestrel/test/Libuv.FunctionalTests/Libuv.FunctionalTests.csproj new file mode 100644 index 0000000000..a0a9254fbf --- /dev/null +++ b/src/Servers/Kestrel/test/Libuv.FunctionalTests/Libuv.FunctionalTests.csproj @@ -0,0 +1,30 @@ + + + + netcoreapp2.1;net461 + $(DefineConstants);MACOS + true + + Libuv.FunctionalTests + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/Servers/Kestrel/test/Libuv.FunctionalTests/ListenHandleTests.cs b/src/Servers/Kestrel/test/Libuv.FunctionalTests/ListenHandleTests.cs new file mode 100644 index 0000000000..61a24966d0 --- /dev/null +++ b/src/Servers/Kestrel/test/Libuv.FunctionalTests/ListenHandleTests.cs @@ -0,0 +1,41 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Net; +using System.Net.Sockets; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Testing; +using Microsoft.AspNetCore.Testing.xunit; +using Microsoft.Extensions.Logging.Testing; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + [OSSkipCondition(OperatingSystems.Windows, SkipReason = "Listening to open TCP socket and/or pipe handles is not supported on Windows.")] + public class ListenHandleTests : LoggedTest + { + [ConditionalFact] + public async Task CanListenToOpenTcpSocketHandle() + { + using (var listenSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + listenSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + + using (var server = new TestServer(_ => Task.CompletedTask, new TestServiceContext(LoggerFactory), new ListenOptions((ulong)listenSocket.Handle))) + { + using (var connection = new TestConnection(((IPEndPoint)listenSocket.LocalEndPoint).Port)) + { + await connection.SendEmptyGet(); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + } + } +} diff --git a/src/Servers/Kestrel/test/Libuv.FunctionalTests/TransportSelector.cs b/src/Servers/Kestrel/test/Libuv.FunctionalTests/TransportSelector.cs new file mode 100644 index 0000000000..db778d603b --- /dev/null +++ b/src/Servers/Kestrel/test/Libuv.FunctionalTests/TransportSelector.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Hosting; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public static class TransportSelector + { + public static IWebHostBuilder GetWebHostBuilder() + { + return new WebHostBuilder().UseLibuv().ConfigureServices(TestServer.RemoveDevCert); + } + } +} diff --git a/src/Servers/Kestrel/test/Sockets.FunctionalTests/Sockets.FunctionalTests.csproj b/src/Servers/Kestrel/test/Sockets.FunctionalTests/Sockets.FunctionalTests.csproj new file mode 100644 index 0000000000..052fd78659 --- /dev/null +++ b/src/Servers/Kestrel/test/Sockets.FunctionalTests/Sockets.FunctionalTests.csproj @@ -0,0 +1,29 @@ + + + + netcoreapp2.1;net461 + $(DefineConstants);MACOS + $(DefineConstants);SOCKETS + true + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/Servers/Kestrel/test/Sockets.FunctionalTests/TransportSelector.cs b/src/Servers/Kestrel/test/Sockets.FunctionalTests/TransportSelector.cs new file mode 100644 index 0000000000..52d2b0a4a8 --- /dev/null +++ b/src/Servers/Kestrel/test/Sockets.FunctionalTests/TransportSelector.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Hosting; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public static class TransportSelector + { + public static IWebHostBuilder GetWebHostBuilder() + { + return new WebHostBuilder().UseSockets().ConfigureServices(TestServer.RemoveDevCert); + } + } +} diff --git a/src/Servers/Kestrel/test/SystemdActivation/Dockerfile b/src/Servers/Kestrel/test/SystemdActivation/Dockerfile new file mode 100644 index 0000000000..5aee5312dd --- /dev/null +++ b/src/Servers/Kestrel/test/SystemdActivation/Dockerfile @@ -0,0 +1,23 @@ +FROM microsoft/dotnet-nightly:2.0-runtime-deps + +# The "container" environment variable is read by systemd. +ENV container=docker + +# Install and configure dependencies. +RUN ["apt-get", "-o", "Acquire::Check-Valid-Until=false", "update"] +RUN ["apt-get", "install", "-y", "--no-install-recommends", "systemd", "socat"] + +# Copy .NET installation. +ADD .dotnet/ /usr/share/dotnet/ +RUN ["ln", "-s", "/usr/share/dotnet/dotnet", "/usr/bin/dotnet"] + +# Copy "publish" app. +ADD publish/ /publish/ + +# Expose target ports. +EXPOSE 8080 8081 8082 8083 8084 8085 + +# Set entrypoint. +COPY ./docker-entrypoint.sh / +RUN chmod +x /docker-entrypoint.sh +ENTRYPOINT ["/docker-entrypoint.sh"] diff --git a/src/Servers/Kestrel/test/SystemdActivation/docker-entrypoint.sh b/src/Servers/Kestrel/test/SystemdActivation/docker-entrypoint.sh new file mode 100644 index 0000000000..cb8d2f2d6f --- /dev/null +++ b/src/Servers/Kestrel/test/SystemdActivation/docker-entrypoint.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +set -e + +cd /publish +systemd-socket-activate -l 8080 -E ASPNETCORE_BASE_PORT=7000 dotnet SystemdTestApp.dll & +socat TCP-LISTEN:8081,fork TCP-CONNECT:127.0.0.1:7000 & +socat TCP-LISTEN:8082,fork TCP-CONNECT:127.0.0.1:7001 & +systemd-socket-activate -l /tmp/activate-kestrel.sock -E ASPNETCORE_BASE_PORT=7100 dotnet SystemdTestApp.dll & +socat TCP-LISTEN:8083,fork UNIX-CLIENT:/tmp/activate-kestrel.sock & +socat TCP-LISTEN:8084,fork TCP-CONNECT:127.0.0.1:7100 & +socat TCP-LISTEN:8085,fork TCP-CONNECT:127.0.0.1:7101 & +trap 'exit 0' SIGTERM +wait diff --git a/src/Servers/Kestrel/test/SystemdActivation/docker.sh b/src/Servers/Kestrel/test/SystemdActivation/docker.sh new file mode 100644 index 0000000000..f2ba5fe506 --- /dev/null +++ b/src/Servers/Kestrel/test/SystemdActivation/docker.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +set -e + +scriptDir=$(dirname "${BASH_SOURCE[0]}") +PATH="$PWD/.dotnet/:$PATH" +dotnet publish -f netcoreapp2.0 ./samples/SystemdTestApp/ +cp -R ./samples/SystemdTestApp/bin/Debug/netcoreapp2.0/publish/ $scriptDir +cp -R ./.dotnet/ $scriptDir + +image=$(docker build -qf $scriptDir/Dockerfile $scriptDir) +container=$(docker run -Pd $image) + +# Try to connect to SystemdTestApp once a second up to 10 times via all available ports. +for i in {1..10}; do + curl -f http://$(docker port $container 8080/tcp) \ + && curl -f http://$(docker port $container 8081/tcp) \ + && curl -fk https://$(docker port $container 8082/tcp) \ + && curl -f http://$(docker port $container 8083/tcp) \ + && curl -f http://$(docker port $container 8084/tcp) \ + && curl -fk https://$(docker port $container 8085/tcp) \ + && exit 0 || sleep 1; +done + +exit -1 diff --git a/src/Servers/Kestrel/tools/CodeGenerator/CodeGenerator.csproj b/src/Servers/Kestrel/tools/CodeGenerator/CodeGenerator.csproj new file mode 100644 index 0000000000..3a209fea5b --- /dev/null +++ b/src/Servers/Kestrel/tools/CodeGenerator/CodeGenerator.csproj @@ -0,0 +1,22 @@ + + + + netcoreapp2.0 + Exe + false + true + false + + + + + + + + + + $(MSBuildThisFileDirectory)..\..\Core\src + Internal/Http/HttpHeaders.Generated.cs Internal/Http/HttpProtocol.Generated.cs Internal/Infrastructure/HttpUtilities.Generated.cs + + + diff --git a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs new file mode 100644 index 0000000000..db6508080e --- /dev/null +++ b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs @@ -0,0 +1,199 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace CodeGenerator +{ + // This project can output the Class library as a NuGet Package. + // To enable this option, right-click on the project and select the Properties menu item. In the Build tab select "Produce outputs on build". + public class HttpProtocolFeatureCollection + { + static string Each(IEnumerable values, Func formatter) + { + return values.Select(formatter).Aggregate((a, b) => a + b); + } + + public class KnownFeature + { + public string Name; + public int Index; + } + + public static string GeneratedFile(string className) + { + var alwaysFeatures = new[] + { + "IHttpRequestFeature", + "IHttpResponseFeature", + "IHttpRequestIdentifierFeature", + "IServiceProvidersFeature", + "IHttpRequestLifetimeFeature", + "IHttpConnectionFeature", + }; + + var commonFeatures = new[] + { + "IHttpAuthenticationFeature", + "IQueryFeature", + "IFormFeature", + }; + + var sometimesFeatures = new[] + { + "IHttpUpgradeFeature", + "IHttp2StreamIdFeature", + "IResponseCookiesFeature", + "IItemsFeature", + "ITlsConnectionFeature", + "IHttpWebSocketFeature", + "ISessionFeature", + "IHttpMaxRequestBodySizeFeature", + "IHttpMinRequestBodyDataRateFeature", + "IHttpMinResponseDataRateFeature", + "IHttpBodyControlFeature", + }; + + var rareFeatures = new[] + { + "IHttpSendFileFeature", + }; + + var allFeatures = alwaysFeatures + .Concat(commonFeatures) + .Concat(sometimesFeatures) + .Concat(rareFeatures) + .Select((type, index) => new KnownFeature + { + Name = type, + Index = index + }); + + // NOTE: This list MUST always match the set of feature interfaces implemented by HttpProtocol. + // See also: src/Kestrel/Http/HttpProtocol.FeatureCollection.cs + var implementedFeatures = new[] + { + "IHttpRequestFeature", + "IHttpResponseFeature", + "IHttpRequestIdentifierFeature", + "IHttpRequestLifetimeFeature", + "IHttpConnectionFeature", + "IHttpMaxRequestBodySizeFeature", + "IHttpMinRequestBodyDataRateFeature", + "IHttpMinResponseDataRateFeature", + "IHttpBodyControlFeature", + }; + + return $@"// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Http.Features.Authentication; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{{ + public partial class {className} + {{{Each(allFeatures, feature => $@" + private static readonly Type {feature.Name}Type = typeof({feature.Name});")} +{Each(allFeatures, feature => $@" + private object _current{feature.Name};")} + + private void FastReset() + {{{Each(implementedFeatures, feature => $@" + _current{feature} = this;")} + {Each(allFeatures.Where(f => !implementedFeatures.Contains(f.Name)), feature => $@" + _current{feature.Name} = null;")} + }} + + object IFeatureCollection.this[Type key] + {{ + get + {{ + object feature = null;{Each(allFeatures, feature => $@" + {(feature.Index != 0 ? "else " : "")}if (key == {feature.Name}Type) + {{ + feature = _current{feature.Name}; + }}")} + else if (MaybeExtra != null) + {{ + feature = ExtraFeatureGet(key); + }} + + return feature ?? ConnectionFeatures[key]; + }} + + set + {{ + _featureRevision++; + {Each(allFeatures, feature => $@" + {(feature.Index != 0 ? "else " : "")}if (key == {feature.Name}Type) + {{ + _current{feature.Name} = value; + }}")} + else + {{ + ExtraFeatureSet(key, value); + }} + }} + }} + + void IFeatureCollection.Set(TFeature feature) + {{ + _featureRevision++;{Each(allFeatures, feature => $@" + {(feature.Index != 0 ? "else " : "")}if (typeof(TFeature) == typeof({feature.Name})) + {{ + _current{feature.Name} = feature; + }}")} + else + {{ + ExtraFeatureSet(typeof(TFeature), feature); + }} + }} + + TFeature IFeatureCollection.Get() + {{ + TFeature feature = default;{Each(allFeatures, feature => $@" + {(feature.Index != 0 ? "else " : "")}if (typeof(TFeature) == typeof({feature.Name})) + {{ + feature = (TFeature)_current{feature.Name}; + }}")} + else if (MaybeExtra != null) + {{ + feature = (TFeature)(ExtraFeatureGet(typeof(TFeature))); + }} + + if (feature == null) + {{ + feature = ConnectionFeatures.Get(); + }} + + return feature; + }} + + private IEnumerable> FastEnumerable() + {{{Each(allFeatures, feature => $@" + if (_current{feature.Name} != null) + {{ + yield return new KeyValuePair({feature.Name}Type, _current{feature.Name} as {feature.Name}); + }}")} + + if (MaybeExtra != null) + {{ + foreach(var item in MaybeExtra) + {{ + yield return item; + }} + }} + }} + }} +}} +"; + } + } +} diff --git a/src/Servers/Kestrel/tools/CodeGenerator/HttpUtilities/CombinationsWithoutRepetition.cs b/src/Servers/Kestrel/tools/CodeGenerator/HttpUtilities/CombinationsWithoutRepetition.cs new file mode 100644 index 0000000000..65a6b5a895 --- /dev/null +++ b/src/Servers/Kestrel/tools/CodeGenerator/HttpUtilities/CombinationsWithoutRepetition.cs @@ -0,0 +1,103 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace CodeGenerator.HttpUtilities +{ + // C code for Algorithm L (Lexicographic combinations) in Section 7.2.1.3 of The Art of Computer Programming, Volume 4A: Combinatorial Algorithms, Part 1 : + internal class CombinationsWithoutRepetition : IEnumerator + { + private bool _firstElement; + private int[] _pointers; + private T[] _nElements; + private readonly int _p; + + public CombinationsWithoutRepetition(T[] nElements, int p) + { + if (nElements.Length < p) throw new ArgumentOutOfRangeException(nameof(p)); + + _nElements = nElements; + _p = p; + Current = new T[p]; + ResetCurrent(); + } + + public T[] Current { get; private set; } + object IEnumerator.Current => Current; + + public bool MoveNext() + { + if (_firstElement) + { + _firstElement = false; + return true; + } + + var p = _p; + var pointers = _pointers; + var current = Current; + var nElements = _nElements; + var index = 1; + + while (pointers[index] + 1 == pointers[index + 1]) + { + var j1 = index - 1; + + pointers[index] = j1; + current[j1] = nElements[j1]; + ++index; + } + + if (index > p) + { + return false; + } + + current[index - 1] = nElements[++pointers[index]]; + + return true; + } + + private void ResetCurrent() + { + var p = _p; + if (_pointers == null) + _pointers = new int[p + 3]; + + var pointers = _pointers; + var current = Current; + var nElements = _nElements; + + pointers[0] = 0; + for (int j = 1; j <= _p; j++) + { + pointers[j] = j - 1; + } + pointers[_p + 1] = nElements.Length; + pointers[_p + 2] = 0; + + for (int j = _p; j > 0; j--) + { + current[j - 1] = nElements[pointers[j]]; + } + _firstElement = true; + } + + public void Reset() + { + Array.Clear(Current, 0, Current.Length); + Current = null; + ResetCurrent(); + } + + public void Dispose() + { + _nElements = null; + Current = null; + _pointers = null; + } + } +} diff --git a/src/Servers/Kestrel/tools/CodeGenerator/HttpUtilities/HttpUtilities.cs b/src/Servers/Kestrel/tools/CodeGenerator/HttpUtilities/HttpUtilities.cs new file mode 100644 index 0000000000..0c70c0c188 --- /dev/null +++ b/src/Servers/Kestrel/tools/CodeGenerator/HttpUtilities/HttpUtilities.cs @@ -0,0 +1,321 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using Microsoft.AspNetCore.Http; + +namespace CodeGenerator.HttpUtilities +{ + public class HttpUtilities + { + public static string GeneratedFile() + { + var httpMethods = new [] + { + new Tuple("CONNECT ", "Connect"), + new Tuple("DELETE ", "Delete"), + new Tuple("HEAD ", "Head"), + new Tuple("PATCH ", "Patch"), + new Tuple("POST ", "Post"), + new Tuple("PUT ", "Put"), + new Tuple("OPTIONS ", "Options"), + new Tuple("TRACE ", "Trace"), + new Tuple("GET ", "Get") + }; + + return GenerateFile(httpMethods); + } + + private static string GenerateFile(Tuple[] httpMethods) + { + var maskLength = (byte)Math.Ceiling(Math.Log(httpMethods.Length, 2)); + + var methodsInfo = httpMethods.Select(GetMethodStringAndUlongAndMaskLength).ToList(); + + var methodsInfoWithoutGet = methodsInfo.Where(m => m.HttpMethod != "Get".ToString()).ToList(); + + var methodsAsciiStringAsLong = methodsInfo.Select(m => m.AsciiStringAsLong).ToArray(); + + var mask = HttpUtilitiesGeneratorHelpers.SearchKeyByLookThroughMaskCombinations(methodsAsciiStringAsLong, 0, sizeof(ulong) * 8, maskLength); + + if (mask.HasValue == false) + { + throw new InvalidOperationException(string.Format("Generated {0} not found.", nameof(mask))); + } + + var functionGetKnownMethodIndex = GetFunctionBodyGetKnownMethodIndex(mask.Value); + + var methodsSection = GetMethodsSection(methodsInfoWithoutGet); + + var masksSection = GetMasksSection(methodsInfoWithoutGet); + + var setKnownMethodSection = GetSetKnownMethodSection(methodsInfoWithoutGet); + var methodNamesSection = GetMethodNamesSection(methodsInfo); + + int knownMethodsArrayLength = (int)(Math.Pow(2, maskLength) + 1); + int methodNamesArrayLength = httpMethods.Length; + + return string.Format(@"// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Runtime.CompilerServices; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{{ + public static partial class HttpUtilities + {{ + // readonly primitive statics can be Jit'd to consts https://github.com/dotnet/coreclr/issues/1079 +{0} + +{1} + private static readonly Tuple[] _knownMethods = + new Tuple[{2}]; + + private static readonly string[] _methodNames = new string[{3}]; + + static HttpUtilities() + {{ +{4} + FillKnownMethodsGaps(); + InitializeHostCharValidity(); +{5} + }} + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int GetKnownMethodIndex(ulong value) + {{ +{6} + }} + }} +}}", methodsSection, masksSection, knownMethodsArrayLength, methodNamesArrayLength, setKnownMethodSection, methodNamesSection, functionGetKnownMethodIndex); + } + + private static string GetMethodsSection(List methodsInfo) + { + var result = new StringBuilder(); + + for (var index = 0; index < methodsInfo.Count; index++) + { + var methodInfo = methodsInfo[index]; + + var httpMethodFieldName = GetHttpMethodFieldName(methodInfo); + result.AppendFormat(" private static readonly ulong {0} = GetAsciiStringAsLong(\"{1}\");", httpMethodFieldName, methodInfo.MethodAsciiString.Replace("\0", "\\0")); + + if (index < methodsInfo.Count - 1) + { + result.AppendLine(); + } + } + + return result.ToString(); + } + + private static string GetMasksSection(List methodsInfo) + { + var distinctLengths = methodsInfo.Select(m => m.MaskLength).Distinct().ToList(); + + distinctLengths.Sort((t1, t2) => -t1.CompareTo(t2)); + + var result = new StringBuilder(); + + for (var index = 0; index < distinctLengths.Count; index++) + { + var maskBytesLength = distinctLengths[index]; + var maskArray = GetMaskArray(maskBytesLength); + + var hexMaskString = HttpUtilitiesGeneratorHelpers.GeHexString(maskArray, "0x", ", "); + var maskFieldName = GetMaskFieldName(maskBytesLength); + + result.AppendFormat(" private static readonly ulong {0} = GetMaskAsLong(new byte[]\r\n {{{1}}});", maskFieldName, hexMaskString); + result.AppendLine(); + if (index < distinctLengths.Count - 1) + { + result.AppendLine(); + } + } + + return result.ToString(); + } + + private static string GetSetKnownMethodSection(List methodsInfo) + { + methodsInfo = methodsInfo.ToList(); + + methodsInfo.Sort((t1, t2) => t1.MaskLength.CompareTo(t2.MaskLength)); + + var result = new StringBuilder(); + + for (var index = 0; index < methodsInfo.Count; index++) + { + var methodInfo = methodsInfo[index]; + var maskFieldName = GetMaskFieldName(methodInfo.MaskLength); + var httpMethodFieldName = GetHttpMethodFieldName(methodInfo); + + result.AppendFormat(" SetKnownMethod({0}, {1}, HttpMethod.{3}, {4});", maskFieldName, httpMethodFieldName, typeof(String).Name, methodInfo.HttpMethod, methodInfo.MaskLength - 1); + + if (index < methodsInfo.Count - 1) + { + result.AppendLine(); + } + } + + return result.ToString(); + } + + private static string GetMethodNamesSection(List methodsInfo) + { + methodsInfo = methodsInfo.ToList(); + + methodsInfo.Sort((t1, t2) => t1.HttpMethod.CompareTo(t2.HttpMethod)); + + var result = new StringBuilder(); + + for (var index = 0; index < methodsInfo.Count; index++) + { + var methodInfo = methodsInfo[index]; + + result.AppendFormat(" _methodNames[(byte)HttpMethod.{1}] = {2}.{3};", typeof(String).Name, methodInfo.HttpMethod, typeof(HttpMethods).Name, methodInfo.HttpMethod); + + if (index < methodsInfo.Count - 1) + { + result.AppendLine(); + } + } + + return result.ToString(); + } + + private static string GetFunctionBodyGetKnownMethodIndex(ulong mask) + { + var shifts = HttpUtilitiesGeneratorHelpers.GetShifts(mask); + + var maskHexString = HttpUtilitiesGeneratorHelpers.MaskToHexString(mask); + + string bodyString; + + if (shifts.Length > 0) + { + var bitsCount = HttpUtilitiesGeneratorHelpers.CountBits(mask); + + var tmpReturn = string.Empty; + foreach (var item in shifts) + { + if (tmpReturn.Length > 0) + { + tmpReturn += " | "; + } + + tmpReturn += string.Format("(tmp >> {1})", HttpUtilitiesGeneratorHelpers.MaskToHexString(item.Mask), item.Shift); + } + + var mask2 = (ulong)(Math.Pow(2, bitsCount) - 1); + + string returnString = string.Format("return ({0}) & {1};", tmpReturn, HttpUtilitiesGeneratorHelpers.MaskToHexString(mask2)); + + bodyString = string.Format(" const int magicNumer = {0};\r\n var tmp = (int)value & magicNumer;\r\n {1}", HttpUtilitiesGeneratorHelpers.MaskToHexString(mask), returnString); + + } + else + { + bodyString = string.Format("return (int)(value & {0});", maskHexString); + } + + return bodyString; + } + + private static string GetHttpMethodFieldName(MethodInfo methodsInfo) + { + return string.Format("_http{0}MethodLong", methodsInfo.HttpMethod.ToString()); + } + + private static string GetMaskFieldName(int nBytes) + { + return string.Format("_mask{0}Chars", nBytes); + } + + private static string GetMethodString(string method) + { + if (method == null) + { + throw new ArgumentNullException(nameof(method)); + } + + const int length = sizeof(ulong); + + if (method.Length > length) + { + throw new ArgumentException(string.Format("MethodAsciiString {0} length is greather than {1}", method, length)); + } + string result = method; + + if (result.Length == length) + { + return result; + } + + if (result.Length < length) + { + var count = length - result.Length; + + for (int i = 0; i < count; i++) + { + result += "\0"; + } + } + + return result; + } + + private class MethodInfo + { + public string MethodAsciiString; + public ulong AsciiStringAsLong; + public string HttpMethod; + public int MaskLength; + } + + private static MethodInfo GetMethodStringAndUlongAndMaskLength(Tuple method) + { + var methodString = GetMethodString(method.Item1); + + var asciiAsLong = GetAsciiStringAsLong(methodString); + + return new MethodInfo + { + MethodAsciiString = methodString, + AsciiStringAsLong = asciiAsLong, + HttpMethod = method.Item2.ToString(), + MaskLength = method.Item1.Length + }; + } + + private static byte[] GetMaskArray(int n, int length = sizeof(ulong)) + { + var maskArray = new byte[length]; + for (int i = 0; i < n; i++) + { + maskArray[i] = 0xff; + } + return maskArray; + } + + private unsafe static ulong GetAsciiStringAsLong(string str) + { + Debug.Assert(str.Length == sizeof(ulong), string.Format("String must be exactly {0} (ASCII) characters long.", sizeof(ulong))); + + var bytes = Encoding.ASCII.GetBytes(str); + + fixed (byte* ptr = &bytes[0]) + { + return *(ulong*)ptr; + } + } + } +} \ No newline at end of file diff --git a/src/Servers/Kestrel/tools/CodeGenerator/HttpUtilities/HttpUtilitiesGeneratorHelpers.cs b/src/Servers/Kestrel/tools/CodeGenerator/HttpUtilities/HttpUtilitiesGeneratorHelpers.cs new file mode 100644 index 0000000000..1ac5de1e66 --- /dev/null +++ b/src/Servers/Kestrel/tools/CodeGenerator/HttpUtilities/HttpUtilitiesGeneratorHelpers.cs @@ -0,0 +1,217 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; + +namespace CodeGenerator.HttpUtilities +{ + internal class HttpUtilitiesGeneratorHelpers + { + public class ShiftInfo + { + public TMask Mask; + public byte Shift; + } + + public static ShiftInfo[] GetShifts(ulong mask) + { + var shifts = new List>(); + + const ulong one = 0x01; + + ulong currentMask = 0; + + int currentBitsCount = 0; + int lastShift = 0; + for (int i = 0; i < sizeof(ulong) * 8; i++) + { + var currentBitMask = one << i; + bool isCurrentBit0 = (currentBitMask & mask) == 0; + + if (isCurrentBit0 == false) + { + currentMask |= currentBitMask; + currentBitsCount++; + } + else if (currentBitsCount > 0) + { + var currentShift = (byte)(i - currentBitsCount - lastShift); + shifts.Add(new ShiftInfo + { + Mask = currentMask, + Shift = currentShift + }); + lastShift = currentShift; + currentMask = 0; + currentBitsCount = 0; + } + } + + return shifts.ToArray(); + } + + public static ulong? SearchKeyByLookThroughMaskCombinations(ulong[] values, byte bitsIndexStart, byte bitsLength, byte bitsCount) + { + if (bitsIndexStart + bitsLength > sizeof(ulong) * 8) + { + throw new ArgumentOutOfRangeException(nameof(bitsIndexStart)); + } + + if (bitsLength < bitsCount || bitsCount == 0) + { + throw new ArgumentOutOfRangeException(nameof(bitsCount)); + } + + var bits = new byte[bitsLength]; + + for (byte i = bitsIndexStart; i < bitsIndexStart + bitsLength; i++) + { + bits[i - bitsIndexStart] = i; + } + + var combinations = new CombinationsWithoutRepetition(bits, bitsCount); + + ulong? maskFound = null; + int bit1ChunksFoundMask = 0; + + int arrayLength = values.Length; + + var mashHash = new HashSet(); + + while (combinations.MoveNext()) + { + var bitsCombination = combinations.Current; + + ulong currentMask = 0; + + for (int i = 0; i < bitsCombination.Length; i++) + { + var index = bitsCombination[i]; + + const ulong oneBit = 0x01; + + currentMask |= oneBit << index; + } + + mashHash.Clear(); + bool invalidMask = false; + for (int j = 0; j < arrayLength; j++) + { + var tmp = values[j] & currentMask; + + bool alreadyExists = mashHash.Add(tmp) == false; + if (alreadyExists) + { + invalidMask = true; + break; + } + } + + if (invalidMask == false) + { + var bit1Chunks = CountBit1Chunks(currentMask); + + if (maskFound.HasValue) + { + if (bit1ChunksFoundMask > bit1Chunks) + { + maskFound = currentMask; + bit1ChunksFoundMask = bit1Chunks; + if (bit1ChunksFoundMask == 0) + { + return maskFound; + } + } + } + else + { + maskFound = currentMask; + bit1ChunksFoundMask = bit1Chunks; + + if (bit1ChunksFoundMask == 0) + { + return maskFound; + } + } + } + } + + return maskFound; + } + + public static int CountBit1Chunks(ulong mask) + { + int currentBitsCount = 0; + + int chunks = 0; + + for (int i = 0; i < sizeof(ulong) * 8; i++) + { + const ulong oneBit = 0x01; + + var currentBitMask = oneBit << i; + bool isCurrentBit0 = (currentBitMask & mask) == 0; + + if (isCurrentBit0 == false) + { + currentBitsCount++; + } + else if (currentBitsCount > 0) + { + chunks++; + currentBitsCount = 0; + } + } + + return chunks; + } + + public static string GeHexString(byte[] array, string prefix, string separator) + { + var result = new StringBuilder(); + int i = 0; + for (; i < array.Length - 1; i++) + { + result.AppendFormat("{0}{1:x2}", prefix, array[i]); + result.Append(separator); + } + + if (array.Length > 0) + { + result.AppendFormat("{0}{1:x2}", prefix, array[i]); + } + + return result.ToString(); + } + + public static string MaskToString(ulong mask) + { + var maskSizeInBIts = Math.Log(mask, 2); + var hexMaskSize = Math.Ceiling(maskSizeInBIts / 4.0); + return string.Format("0x{0:X" + hexMaskSize + "}", mask); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int CountBits(ulong v) + { + const ulong Mask01010101 = 0x5555555555555555UL; + const ulong Mask00110011 = 0x3333333333333333UL; + const ulong Mask00001111 = 0x0F0F0F0F0F0F0F0FUL; + const ulong Mask00000001 = 0x0101010101010101UL; + v = v - ((v >> 1) & Mask01010101); + v = (v & Mask00110011) + ((v >> 2) & Mask00110011); + return (int)(unchecked(((v + (v >> 4)) & Mask00001111) * Mask00000001) >> 56); + } + + public static string MaskToHexString(ulong mask) + { + var maskSizeInBIts = Math.Log(mask, 2); + var hexMaskSize = (byte)Math.Ceiling(maskSizeInBIts / 4); + + return string.Format("0x{0:X" + (hexMaskSize == 0 ? 1 : hexMaskSize) + "}", mask); + } + } +} diff --git a/src/Servers/Kestrel/tools/CodeGenerator/KnownHeaders.cs b/src/Servers/Kestrel/tools/CodeGenerator/KnownHeaders.cs new file mode 100644 index 0000000000..08646dc61a --- /dev/null +++ b/src/Servers/Kestrel/tools/CodeGenerator/KnownHeaders.cs @@ -0,0 +1,662 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; + +namespace CodeGenerator +{ + // This project can output the Class library as a NuGet Package. + // To enable this option, right-click on the project and select the Properties menu item. In the Build tab select "Produce outputs on build". + public class KnownHeaders + { + static string Each(IEnumerable values, Func formatter) + { + return values.Any() ? values.Select(formatter).Aggregate((a, b) => a + b) : ""; + } + + static string If(bool condition, Func formatter) + { + return condition ? formatter() : ""; + } + + static string AppendSwitch(IEnumerable> values, string className) => + $@"var pUL = (ulong*)pUB; + var pUI = (uint*)pUB; + var pUS = (ushort*)pUB; + var stringValue = new StringValues(value); + switch (keyLength) + {{{Each(values, byLength => $@" + case {byLength.Key}: + {{{Each(byLength, header => $@" + if ({header.EqualIgnoreCaseBytes()}) + {{{(header.Identifier == "ContentLength" ? $@" + if (_contentLength.HasValue) + {{ + BadHttpRequestException.Throw(RequestRejectionReason.MultipleContentLengths); + }} + else + {{ + _contentLength = ParseContentLength(value); + }} + return;" : $@" + if ({header.TestBit()}) + {{ + _headers._{header.Identifier} = AppendValue(_headers._{header.Identifier}, value); + }} + else + {{ + {header.SetBit()}; + _headers._{header.Identifier} = stringValue;{(header.EnhancedSetter == false ? "" : $@" + _headers._raw{header.Identifier} = null;")} + }} + return;")} + }} + ")}}} + break; + ")}}}"; + + class KnownHeader + { + public string Name { get; set; } + public int Index { get; set; } + public string Identifier => Name.Replace("-", ""); + + public byte[] Bytes => Encoding.ASCII.GetBytes($"\r\n{Name}: "); + public int BytesOffset { get; set; } + public int BytesCount { get; set; } + public bool ExistenceCheck { get; set; } + public bool FastCount { get; set; } + public bool EnhancedSetter { get; set; } + public bool PrimaryHeader { get; set; } + public string TestBit() => $"(_bits & {1L << Index}L) != 0"; + public string TestTempBit() => $"(tempBits & {1L << Index}L) != 0"; + public string TestNotTempBit() => $"(tempBits & ~{1L << Index}L) == 0"; + public string TestNotBit() => $"(_bits & {1L << Index}L) == 0"; + public string SetBit() => $"_bits |= {1L << Index}L"; + public string ClearBit() => $"_bits &= ~{1L << Index}L"; + + public string EqualIgnoreCaseBytes() + { + var result = ""; + var delim = ""; + var index = 0; + while (index != Name.Length) + { + if (Name.Length - index >= 8) + { + result += delim + Term(Name, index, 8, "pUL", "uL"); + index += 8; + } + else if (Name.Length - index >= 4) + { + result += delim + Term(Name, index, 4, "pUI", "u"); + index += 4; + } + else if (Name.Length - index >= 2) + { + result += delim + Term(Name, index, 2, "pUS", "u"); + index += 2; + } + else + { + result += delim + Term(Name, index, 1, "pUB", "u"); + index += 1; + } + delim = " && "; + } + return $"({result})"; + } + protected string Term(string name, int offset, int count, string array, string suffix) + { + ulong mask = 0; + ulong comp = 0; + for (var scan = 0; scan < count; scan++) + { + var ch = (byte)name[offset + count - scan - 1]; + var isAlpha = (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z'); + comp = (comp << 8) + (ch & (isAlpha ? 0xdfu : 0xffu)); + mask = (mask << 8) + (isAlpha ? 0xdfu : 0xffu); + } + return $"(({array}[{offset / count}] & {mask}{suffix}) == {comp}{suffix})"; + } + } + + public static string GeneratedFile() + { + var requestPrimaryHeaders = new[] + { + "Accept", + "Connection", + "Host", + "User-Agent" + }; + var responsePrimaryHeaders = new[] + { + "Connection", + "Date", + "Content-Type", + "Server", + }; + var commonHeaders = new[] + { + "Cache-Control", + "Connection", + "Date", + "Keep-Alive", + "Pragma", + "Trailer", + "Transfer-Encoding", + "Upgrade", + "Via", + "Warning", + "Allow", + "Content-Type", + "Content-Encoding", + "Content-Language", + "Content-Location", + "Content-MD5", + "Content-Range", + "Expires", + "Last-Modified" + }; + // http://www.w3.org/TR/cors/#syntax + var corsRequestHeaders = new[] + { + "Origin", + "Access-Control-Request-Method", + "Access-Control-Request-Headers", + }; + var requestHeadersExistence = new[] + { + "Connection", + "Transfer-Encoding", + }; + var requestHeadersCount = new[] + { + "Host" + }; + var requestHeaders = commonHeaders.Concat(new[] + { + "Accept", + "Accept-Charset", + "Accept-Encoding", + "Accept-Language", + "Authorization", + "Cookie", + "Expect", + "From", + "Host", + "If-Match", + "If-Modified-Since", + "If-None-Match", + "If-Range", + "If-Unmodified-Since", + "Max-Forwards", + "Proxy-Authorization", + "Referer", + "Range", + "TE", + "Translate", + "User-Agent", + }) + .Concat(corsRequestHeaders) + .Select((header, index) => new KnownHeader + { + Name = header, + Index = index, + PrimaryHeader = requestPrimaryHeaders.Contains(header), + ExistenceCheck = requestHeadersExistence.Contains(header), + FastCount = requestHeadersCount.Contains(header) + }) + .Concat(new[] { new KnownHeader + { + Name = "Content-Length", + Index = -1, + PrimaryHeader = requestPrimaryHeaders.Contains("Content-Length") + }}) + .ToArray(); + Debug.Assert(requestHeaders.Length <= 64); + Debug.Assert(requestHeaders.Max(x => x.Index) <= 62); + + var responseHeadersExistence = new[] + { + "Connection", + "Server", + "Date", + "Transfer-Encoding" + }; + var enhancedHeaders = new[] + { + "Connection", + "Server", + "Date", + "Transfer-Encoding" + }; + // http://www.w3.org/TR/cors/#syntax + var corsResponseHeaders = new[] + { + "Access-Control-Allow-Credentials", + "Access-Control-Allow-Headers", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Origin", + "Access-Control-Expose-Headers", + "Access-Control-Max-Age", + }; + var responseHeaders = commonHeaders.Concat(new[] + { + "Accept-Ranges", + "Age", + "ETag", + "Location", + "Proxy-Authenticate", + "Retry-After", + "Server", + "Set-Cookie", + "Vary", + "WWW-Authenticate", + }) + .Concat(corsResponseHeaders) + .Select((header, index) => new KnownHeader + { + Name = header, + Index = index, + EnhancedSetter = enhancedHeaders.Contains(header), + ExistenceCheck = responseHeadersExistence.Contains(header), + PrimaryHeader = responsePrimaryHeaders.Contains(header) + }) + .Concat(new[] { new KnownHeader + { + Name = "Content-Length", + Index = -1, + EnhancedSetter = enhancedHeaders.Contains("Content-Length"), + PrimaryHeader = responsePrimaryHeaders.Contains("Content-Length") + }}) + .ToArray(); + // 63 for reponseHeaders as it steals one bit for Content-Length in CopyTo(ref MemoryPoolIterator output) + Debug.Assert(responseHeaders.Length <= 63); + Debug.Assert(responseHeaders.Max(x => x.Index) <= 62); + + var loops = new[] + { + new + { + Headers = requestHeaders, + HeadersByLength = requestHeaders.GroupBy(x => x.Name.Length), + ClassName = "HttpRequestHeaders", + Bytes = default(byte[]) + }, + new + { + Headers = responseHeaders, + HeadersByLength = responseHeaders.GroupBy(x => x.Name.Length), + ClassName = "HttpResponseHeaders", + Bytes = responseHeaders.SelectMany(header => header.Bytes).ToArray() + } + }; + foreach (var loop in loops.Where(l => l.Bytes != null)) + { + var offset = 0; + foreach (var header in loop.Headers) + { + header.BytesOffset = offset; + header.BytesCount += header.Bytes.Length; + offset += header.BytesCount; + } + } + return $@"// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using System.Buffers; +using System.IO.Pipelines; +using Microsoft.Extensions.Primitives; +using Microsoft.Net.Http.Headers; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{{ +{Each(loops, loop => $@" + public partial class {loop.ClassName} + {{{(loop.Bytes != null ? + $@" + private static byte[] _headerBytes = new byte[] + {{ + {Each(loop.Bytes, b => $"{b},")} + }};" + : "")} + + private long _bits = 0; + private HeaderReferences _headers; +{Each(loop.Headers.Where(header => header.ExistenceCheck), header => $@" + public bool Has{header.Identifier} => {header.TestBit()};")} +{Each(loop.Headers.Where(header => header.FastCount), header => $@" + public int {header.Identifier}Count => _headers._{header.Identifier}.Count;")} + {Each(loop.Headers, header => $@" + public StringValues Header{header.Identifier} + {{{(header.Identifier == "ContentLength" ? $@" + get + {{ + StringValues value; + if (_contentLength.HasValue) + {{ + value = new StringValues(HeaderUtilities.FormatNonNegativeInt64(_contentLength.Value)); + }} + return value; + }} + set + {{ + _contentLength = ParseContentLength(value); + }}" : $@" + get + {{ + StringValues value; + if ({header.TestBit()}) + {{ + value = _headers._{header.Identifier}; + }} + return value; + }} + set + {{ + {header.SetBit()}; + _headers._{header.Identifier} = value; {(header.EnhancedSetter == false ? "" : $@" + _headers._raw{header.Identifier} = null;")} + }}")} + }}")} +{Each(loop.Headers.Where(header => header.EnhancedSetter), header => $@" + public void SetRaw{header.Identifier}(in StringValues value, byte[] raw) + {{ + {header.SetBit()}; + _headers._{header.Identifier} = value; + _headers._raw{header.Identifier} = raw; + }}")} + protected override int GetCountFast() + {{ + return (_contentLength.HasValue ? 1 : 0 ) + BitCount(_bits) + (MaybeUnknown?.Count ?? 0); + }} + + protected override bool TryGetValueFast(string key, out StringValues value) + {{ + switch (key.Length) + {{{Each(loop.HeadersByLength, byLength => $@" + case {byLength.Key}: + {{{Each(byLength, header => $@" + if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase)) + {{{(header.Identifier == "ContentLength" ? @" + if (_contentLength.HasValue) + { + value = HeaderUtilities.FormatNonNegativeInt64(_contentLength.Value); + return true; + } + return false;" : $@" + if ({header.TestBit()}) + {{ + value = _headers._{header.Identifier}; + return true; + }} + return false;")} + }}")} + }} + break;")} + }} + + return MaybeUnknown?.TryGetValue(key, out value) ?? false; + }} + + protected override void SetValueFast(string key, in StringValues value) + {{{(loop.ClassName == "HttpResponseHeaders" ? @" + ValidateHeaderCharacters(value);" : "")} + switch (key.Length) + {{{Each(loop.HeadersByLength, byLength => $@" + case {byLength.Key}: + {{{Each(byLength, header => $@" + if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase)) + {{{(header.Identifier == "ContentLength" ? $@" + _contentLength = ParseContentLength(value.ToString());" : $@" + {header.SetBit()}; + _headers._{header.Identifier} = value;{(header.EnhancedSetter == false ? "" : $@" + _headers._raw{header.Identifier} = null;")}")} + return; + }}")} + }} + break;")} + }} + + SetValueUnknown(key, value); + }} + + protected override bool AddValueFast(string key, in StringValues value) + {{{(loop.ClassName == "HttpResponseHeaders" ? @" + ValidateHeaderCharacters(value);" : "")} + switch (key.Length) + {{{Each(loop.HeadersByLength, byLength => $@" + case {byLength.Key}: + {{{Each(byLength, header => $@" + if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase)) + {{{(header.Identifier == "ContentLength" ? $@" + if (!_contentLength.HasValue) + {{ + _contentLength = ParseContentLength(value); + return true; + }} + return false;" : $@" + if ({header.TestNotBit()}) + {{ + {header.SetBit()}; + _headers._{header.Identifier} = value;{(header.EnhancedSetter == false ? "" : $@" + _headers._raw{header.Identifier} = null;")} + return true; + }} + return false;")} + }}")} + }} + break;")} + }} +{(loop.ClassName == "HttpResponseHeaders" ? @" + ValidateHeaderCharacters(key);" : "")} + Unknown.Add(key, value); + // Return true, above will throw and exit for false + return true; + }} + + protected override bool RemoveFast(string key) + {{ + switch (key.Length) + {{{Each(loop.HeadersByLength, byLength => $@" + case {byLength.Key}: + {{{Each(byLength, header => $@" + if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase)) + {{{(header.Identifier == "ContentLength" ? @" + if (_contentLength.HasValue) + { + _contentLength = null; + return true; + } + return false;" : $@" + if ({header.TestBit()}) + {{ + {header.ClearBit()}; + _headers._{header.Identifier} = default(StringValues);{(header.EnhancedSetter == false ? "" : $@" + _headers._raw{header.Identifier} = null;")} + return true; + }} + return false;")} + }}")} + }} + break;")} + }} + + return MaybeUnknown?.Remove(key) ?? false; + }} + + protected override void ClearFast() + {{ + MaybeUnknown?.Clear(); + _contentLength = null; + var tempBits = _bits; + _bits = 0; + if(HttpHeaders.BitCount(tempBits) > 12) + {{ + _headers = default(HeaderReferences); + return; + }} + {Each(loop.Headers.Where(header => header.Identifier != "ContentLength").OrderBy(h => !h.PrimaryHeader), header => $@" + if ({header.TestTempBit()}) + {{ + _headers._{header.Identifier} = default(StringValues); + if({header.TestNotTempBit()}) + {{ + return; + }} + tempBits &= ~{1L << header.Index}L; + }} + ")} + }} + + protected override bool CopyToFast(KeyValuePair[] array, int arrayIndex) + {{ + if (arrayIndex < 0) + {{ + return false; + }} + {Each(loop.Headers.Where(header => header.Identifier != "ContentLength"), header => $@" + if ({header.TestBit()}) + {{ + if (arrayIndex == array.Length) + {{ + return false; + }} + array[arrayIndex] = new KeyValuePair(""{header.Name}"", _headers._{header.Identifier}); + ++arrayIndex; + }}")} + if (_contentLength.HasValue) + {{ + if (arrayIndex == array.Length) + {{ + return false; + }} + array[arrayIndex] = new KeyValuePair(""Content-Length"", HeaderUtilities.FormatNonNegativeInt64(_contentLength.Value)); + ++arrayIndex; + }} + ((ICollection>)MaybeUnknown)?.CopyTo(array, arrayIndex); + + return true; + }} + {(loop.ClassName == "HttpResponseHeaders" ? $@" + internal void CopyToFast(ref BufferWriter output) + {{ + var tempBits = _bits | (_contentLength.HasValue ? {1L << 63}L : 0); + {Each(loop.Headers.Where(header => header.Identifier != "ContentLength").OrderBy(h => !h.PrimaryHeader), header => $@" + if ({header.TestTempBit()}) + {{ {(header.EnhancedSetter == false ? "" : $@" + if (_headers._raw{header.Identifier} != null) + {{ + output.Write(_headers._raw{header.Identifier}); + }} + else ")} + {{ + var valueCount = _headers._{header.Identifier}.Count; + for (var i = 0; i < valueCount; i++) + {{ + var value = _headers._{header.Identifier}[i]; + if (value != null) + {{ + output.Write(new ReadOnlySpan(_headerBytes, {header.BytesOffset}, {header.BytesCount})); + PipelineExtensions.WriteAsciiNoValidation(ref output, value); + }} + }} + }} + + if({header.TestNotTempBit()}) + {{ + return; + }} + tempBits &= ~{1L << header.Index}L; + }}{(header.Identifier == "Server" ? $@" + if ((tempBits & {1L << 63}L) != 0) + {{ + output.Write(new ReadOnlySpan(_headerBytes, {loop.Headers.First(x => x.Identifier == "ContentLength").BytesOffset}, {loop.Headers.First(x => x.Identifier == "ContentLength").BytesCount})); + PipelineExtensions.WriteNumeric(ref output, (ulong)ContentLength.Value); + + if((tempBits & ~{1L << 63}L) == 0) + {{ + return; + }} + tempBits &= ~{1L << 63}L; + }}" : "")}")} + }}" : "")} + {(loop.ClassName == "HttpRequestHeaders" ? $@" + public unsafe void Append(byte* pKeyBytes, int keyLength, string value) + {{ + var pUB = pKeyBytes; + {AppendSwitch(loop.Headers.Where(h => h.PrimaryHeader).GroupBy(x => x.Name.Length), loop.ClassName)} + + AppendNonPrimaryHeaders(pKeyBytes, keyLength, value); + }} + + private unsafe void AppendNonPrimaryHeaders(byte* pKeyBytes, int keyLength, string value) + {{ + var pUB = pKeyBytes; + {AppendSwitch(loop.Headers.Where(h => !h.PrimaryHeader).GroupBy(x => x.Name.Length), loop.ClassName)} + + AppendUnknownHeaders(pKeyBytes, keyLength, value); + }}" : "")} + + private struct HeaderReferences + {{{Each(loop.Headers.Where(header => header.Identifier != "ContentLength"), header => @" + public StringValues _" + header.Identifier + ";")} + {Each(loop.Headers.Where(header => header.EnhancedSetter), header => @" + public byte[] _raw" + header.Identifier + ";")} + }} + + public partial struct Enumerator + {{ + public bool MoveNext() + {{ + switch (_state) + {{ + {Each(loop.Headers.Where(header => header.Identifier != "ContentLength"), header => $@" + case {header.Index}: + goto state{header.Index}; + ")} + case {loop.Headers.Count()}: + goto state{loop.Headers.Count()}; + default: + goto state_default; + }} + {Each(loop.Headers.Where(header => header.Identifier != "ContentLength"), header => $@" + state{header.Index}: + if ({header.TestBit()}) + {{ + _current = new KeyValuePair(""{header.Name}"", _collection._headers._{header.Identifier}); + _state = {header.Index + 1}; + return true; + }} + ")} + state{loop.Headers.Count()}: + if (_collection._contentLength.HasValue) + {{ + _current = new KeyValuePair(""Content-Length"", HeaderUtilities.FormatNonNegativeInt64(_collection._contentLength.Value)); + _state = {loop.Headers.Count() + 1}; + return true; + }} + state_default: + if (!_hasUnknown || !_unknownEnumerator.MoveNext()) + {{ + _current = default(KeyValuePair); + return false; + }} + _current = _unknownEnumerator.Current; + return true; + }} + }} + }} +")}}}"; + } + } +} diff --git a/src/Servers/Kestrel/tools/CodeGenerator/Program.cs b/src/Servers/Kestrel/tools/CodeGenerator/Program.cs new file mode 100644 index 0000000000..a7870adec4 --- /dev/null +++ b/src/Servers/Kestrel/tools/CodeGenerator/Program.cs @@ -0,0 +1,59 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; + +namespace CodeGenerator +{ + public class Program + { + public static int Main(string[] args) + { + if (args.Length < 1) + { + Console.Error.WriteLine("Missing path to HttpHeaders.Generated.cs"); + return 1; + } + else if (args.Length < 2) + { + Console.Error.WriteLine("Missing path to HttpProtocol.Generated.cs"); + return 1; + } + else if (args.Length < 3) + { + Console.Error.WriteLine("Missing path to HttpUtilities.Generated.cs"); + return 1; + } + + Run(args[0], args[1], args[2]); + + return 0; + } + + public static void Run(string knownHeadersPath, string httpProtocolFeatureCollectionPath, string httpUtilitiesPath) + { + var knownHeadersContent = KnownHeaders.GeneratedFile(); + var httpProtocolFeatureCollectionContent = HttpProtocolFeatureCollection.GeneratedFile("HttpProtocol"); + var httpUtilitiesContent = HttpUtilities.HttpUtilities.GeneratedFile(); + + var existingKnownHeaders = File.Exists(knownHeadersPath) ? File.ReadAllText(knownHeadersPath) : ""; + if (!string.Equals(knownHeadersContent, existingKnownHeaders)) + { + File.WriteAllText(knownHeadersPath, knownHeadersContent); + } + + var existingHttpProtocolFeatureCollection = File.Exists(httpProtocolFeatureCollectionPath) ? File.ReadAllText(httpProtocolFeatureCollectionPath) : ""; + if (!string.Equals(httpProtocolFeatureCollectionContent, existingHttpProtocolFeatureCollection)) + { + File.WriteAllText(httpProtocolFeatureCollectionPath, httpProtocolFeatureCollectionContent); + } + + var existingHttpUtilities = File.Exists(httpUtilitiesPath) ? File.ReadAllText(httpUtilitiesPath) : ""; + if (!string.Equals(httpUtilitiesContent, existingHttpUtilities)) + { + File.WriteAllText(httpUtilitiesPath, httpUtilitiesContent); + } + } + } +} diff --git a/src/Servers/Kestrel/xunit.runner.json b/src/Servers/Kestrel/xunit.runner.json new file mode 100644 index 0000000000..3a5192e57d --- /dev/null +++ b/src/Servers/Kestrel/xunit.runner.json @@ -0,0 +1,6 @@ +{ + "$schema": "http://json.schemastore.org/xunit.runner.schema", + "appDomain": "denied", + "methodDisplay": "method", + "longRunningTestSeconds": 60 +} diff --git a/src/Shared/Buffers.Testing/BufferSegment.cs b/src/Shared/Buffers.Testing/BufferSegment.cs new file mode 100644 index 0000000000..d89f4addd5 --- /dev/null +++ b/src/Shared/Buffers.Testing/BufferSegment.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Buffers +{ + internal class BufferSegment : ReadOnlySequenceSegment + { + public BufferSegment(Memory memory) + { + Memory = memory; + } + + public BufferSegment Append(Memory memory) + { + var segment = new BufferSegment(memory) + { + RunningIndex = RunningIndex + Memory.Length + }; + Next = segment; + return segment; + } + } +} diff --git a/src/Shared/Buffers.Testing/CustomMemoryForTest.cs b/src/Shared/Buffers.Testing/CustomMemoryForTest.cs new file mode 100644 index 0000000000..20406f0a99 --- /dev/null +++ b/src/Shared/Buffers.Testing/CustomMemoryForTest.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Buffers +{ + internal class CustomMemoryForTest : IMemoryOwner + { + private bool _disposed; + private T[] _array; + private readonly int _offset; + private readonly int _length; + + public CustomMemoryForTest(T[] array): this(array, 0, array.Length) + { + } + + public CustomMemoryForTest(T[] array, int offset, int length) + { + _array = array; + _offset = offset; + _length = length; + } + + public Memory Memory + { + get + { + if (_disposed) + throw new ObjectDisposedException(nameof(CustomMemoryForTest)); + return new Memory(_array, _offset, _length); + } + } + + public void Dispose() + { + if (_disposed) + return; + + _array = null; + _disposed = true; + } + } +} + diff --git a/src/Shared/Buffers.Testing/ReadOnlySequenceFactory.cs b/src/Shared/Buffers.Testing/ReadOnlySequenceFactory.cs new file mode 100644 index 0000000000..0fc0c6585f --- /dev/null +++ b/src/Shared/Buffers.Testing/ReadOnlySequenceFactory.cs @@ -0,0 +1,148 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Text; + +namespace System.Buffers +{ + internal abstract class ReadOnlySequenceFactory + { + public static ReadOnlySequenceFactory ArrayFactory { get; } = new ArrayTestSequenceFactory(); + public static ReadOnlySequenceFactory MemoryFactory { get; } = new MemoryTestSequenceFactory(); + public static ReadOnlySequenceFactory OwnedMemoryFactory { get; } = new OwnedMemoryTestSequenceFactory(); + public static ReadOnlySequenceFactory SingleSegmentFactory { get; } = new SingleSegmentTestSequenceFactory(); + public static ReadOnlySequenceFactory SegmentPerByteFactory { get; } = new BytePerSegmentTestSequenceFactory(); + + public abstract ReadOnlySequence CreateOfSize(int size); + public abstract ReadOnlySequence CreateWithContent(byte[] data); + + public ReadOnlySequence CreateWithContent(string data) + { + return CreateWithContent(Encoding.ASCII.GetBytes(data)); + } + + internal class ArrayTestSequenceFactory : ReadOnlySequenceFactory + { + public override ReadOnlySequence CreateOfSize(int size) + { + return new ReadOnlySequence(new byte[size + 20], 10, size); + } + + public override ReadOnlySequence CreateWithContent(byte[] data) + { + var startSegment = new byte[data.Length + 20]; + Array.Copy(data, 0, startSegment, 10, data.Length); + return new ReadOnlySequence(startSegment, 10, data.Length); + } + } + + internal class MemoryTestSequenceFactory : ReadOnlySequenceFactory + { + public override ReadOnlySequence CreateOfSize(int size) + { + return CreateWithContent(new byte[size]); + } + + public override ReadOnlySequence CreateWithContent(byte[] data) + { + var startSegment = new byte[data.Length + 20]; + Array.Copy(data, 0, startSegment, 10, data.Length); + return new ReadOnlySequence(new Memory(startSegment, 10, data.Length)); + } + } + + internal class OwnedMemoryTestSequenceFactory : ReadOnlySequenceFactory + { + public override ReadOnlySequence CreateOfSize(int size) + { + return CreateWithContent(new byte[size]); + } + + public override ReadOnlySequence CreateWithContent(byte[] data) + { + var startSegment = new byte[data.Length + 20]; + Array.Copy(data, 0, startSegment, 10, data.Length); + return new ReadOnlySequence(new CustomMemoryForTest(startSegment, 10, data.Length).Memory); + } + } + + internal class SingleSegmentTestSequenceFactory : ReadOnlySequenceFactory + { + public override ReadOnlySequence CreateOfSize(int size) + { + return CreateWithContent(new byte[size]); + } + + public override ReadOnlySequence CreateWithContent(byte[] data) + { + return CreateSegments(data); + } + } + + internal class BytePerSegmentTestSequenceFactory : ReadOnlySequenceFactory + { + public override ReadOnlySequence CreateOfSize(int size) + { + return CreateWithContent(new byte[size]); + } + + public override ReadOnlySequence CreateWithContent(byte[] data) + { + var segments = new List(); + + segments.Add(Array.Empty()); + foreach (var b in data) + { + segments.Add(new[] { b }); + segments.Add(Array.Empty()); + } + + return CreateSegments(segments.ToArray()); + } + } + + public static ReadOnlySequence CreateSegments(params byte[][] inputs) + { + if (inputs == null || inputs.Length == 0) + { + throw new InvalidOperationException(); + } + + int i = 0; + + BufferSegment last = null; + BufferSegment first = null; + + do + { + byte[] s = inputs[i]; + int length = s.Length; + int dataOffset = length; + var chars = new byte[length * 2]; + + for (int j = 0; j < length; j++) + { + chars[dataOffset + j] = s[j]; + } + + // Create a segment that has offset relative to the OwnedMemory and OwnedMemory itself has offset relative to array + var memory = new Memory(chars).Slice(length, length); + + if (first == null) + { + first = new BufferSegment(memory); + last = first; + } + else + { + last = last.Append(memory); + } + i++; + } while (i < inputs.Length); + + return new ReadOnlySequence(first, 0, last, last.Memory.Length); + } + } +}