diff --git a/.vscode/launch.json b/.vscode/launch.json index 2db39e4359..db6faceb9c 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -13,6 +13,30 @@ "request": "attach", "processId": "${command:pickProcess}" }, + { + "name": "Debug: TlsApp", + "type": "coreclr", + "request": "launch", + "program": "${workspaceRoot}/samples/TlsApp/bin/Debug/netcoreapp2.0/TlsApp.dll", + "cwd": "${workspaceRoot}/samples/TlsApp", + "console": "internalConsole", + "stopAtEntry": false, + "internalConsoleOptions": "openOnSessionStart", + "launchBrowser": { + "enabled": true, + "args": "https://127.0.0.1:5000", + "windows": { + "command": "cmd.exe", + "args": "/C start ${auto-detect-url}" + }, + "osx": { + "command": "open" + }, + "linux": { + "command": "xdg-open" + } + } + }, { "name": "Debug: SampleApp", "type": "coreclr", diff --git a/KestrelHttpServer.sln b/KestrelHttpServer.sln index 60221f61ae..44ad81ebae 100644 --- a/KestrelHttpServer.sln +++ b/KestrelHttpServer.sln @@ -82,6 +82,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Kestrel.Tests", "test\Kestr EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Protocols.Abstractions", "src\Protocols.Abstractions\Protocols.Abstractions.csproj", "{6956CF5C-3163-4398-8628-4ECA569245B5}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Kestrel.Tls", "src\Kestrel.Tls\Kestrel.Tls.csproj", "{924AE57C-1EBA-4A1D-A039-8C100B7507A5}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -272,12 +274,23 @@ Global {6956CF5C-3163-4398-8628-4ECA569245B5}.Release|x64.Build.0 = Release|Any CPU {6956CF5C-3163-4398-8628-4ECA569245B5}.Release|x86.ActiveCfg = Release|Any CPU {6956CF5C-3163-4398-8628-4ECA569245B5}.Release|x86.Build.0 = Release|Any CPU + {924AE57C-1EBA-4A1D-A039-8C100B7507A5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {924AE57C-1EBA-4A1D-A039-8C100B7507A5}.Debug|Any CPU.Build.0 = Debug|Any CPU + {924AE57C-1EBA-4A1D-A039-8C100B7507A5}.Debug|x64.ActiveCfg = Debug|x64 + {924AE57C-1EBA-4A1D-A039-8C100B7507A5}.Debug|x64.Build.0 = Debug|x64 + {924AE57C-1EBA-4A1D-A039-8C100B7507A5}.Debug|x86.ActiveCfg = Debug|x86 + {924AE57C-1EBA-4A1D-A039-8C100B7507A5}.Debug|x86.Build.0 = Debug|x86 + {924AE57C-1EBA-4A1D-A039-8C100B7507A5}.Release|Any CPU.ActiveCfg = Release|Any CPU + {924AE57C-1EBA-4A1D-A039-8C100B7507A5}.Release|Any CPU.Build.0 = Release|Any CPU + {924AE57C-1EBA-4A1D-A039-8C100B7507A5}.Release|x64.ActiveCfg = Release|x64 + {924AE57C-1EBA-4A1D-A039-8C100B7507A5}.Release|x64.Build.0 = Release|x64 + {924AE57C-1EBA-4A1D-A039-8C100B7507A5}.Release|x86.ActiveCfg = Release|x86 + {924AE57C-1EBA-4A1D-A039-8C100B7507A5}.Release|x86.Build.0 = Release|x86 EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection GlobalSection(NestedProjects) = preSolution - {0EF2ACDF-012F-4472-A13A-4272419E2903} = {D3273454-EA07-41D2-BF0B-FCC3675C2483} {F510611A-3BEE-4B88-A613-5F4A74ED82A1} = {2D5D5227-4DBD-499A-96B1-76A36B03B750} {37F3BFB2-6454-49E5-9D7F-581BF755CCFE} = {D3273454-EA07-41D2-BF0B-FCC3675C2483} {2C3CB3DC-EEBF-4F52-9E1C-4F2F972E76C3} = {8A3D00B8-1CCF-4BE6-A060-11104CE2D9CE} @@ -286,7 +299,6 @@ Global {5F64B3C3-0C2E-431A-B820-A81BBFC863DA} = {2D5D5227-4DBD-499A-96B1-76A36B03B750} {9559A5F1-080C-4909-B6CF-7E4B3DC55748} = {D3273454-EA07-41D2-BF0B-FCC3675C2483} {EBFE9719-A44B-4978-A71F-D5C254E7F35A} = {D3273454-EA07-41D2-BF0B-FCC3675C2483} - {2822C132-BFFB-4D53-AC5B-E7E47DD81A6E} = {0EF2ACDF-012F-4472-A13A-4272419E2903} {A76B8C8C-0DC5-4DD3-9B1F-02E51A0915F4} = {2D5D5227-4DBD-499A-96B1-76A36B03B750} {56139957-5C29-4E7D-89BD-7D20598B4EAF} = {2D5D5227-4DBD-499A-96B1-76A36B03B750} {6950B18F-A3D2-41A4-AFEC-8B7C49517611} = {2D5D5227-4DBD-499A-96B1-76A36B03B750} @@ -294,6 +306,7 @@ Global {D95A7EC3-48AC-4D03-B2E2-0DA3E13BD3A4} = {D3273454-EA07-41D2-BF0B-FCC3675C2483} {4F1C30F8-CCAA-48D7-9DF6-2A84021F5BCC} = {D3273454-EA07-41D2-BF0B-FCC3675C2483} {6956CF5C-3163-4398-8628-4ECA569245B5} = {2D5D5227-4DBD-499A-96B1-76A36B03B750} + {924AE57C-1EBA-4A1D-A039-8C100B7507A5} = {2D5D5227-4DBD-499A-96B1-76A36B03B750} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {2D10D020-6770-47CA-BB8D-2C23FE3AE071} diff --git a/src/Kestrel.Core/Adapter/Internal/AdaptedPipeline.cs b/src/Kestrel.Core/Adapter/Internal/AdaptedPipeline.cs index 8f6db82a0d..823fa67b52 100644 --- a/src/Kestrel.Core/Adapter/Internal/AdaptedPipeline.cs +++ b/src/Kestrel.Core/Adapter/Internal/AdaptedPipeline.cs @@ -165,4 +165,4 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal } } } -} \ No newline at end of file +} diff --git a/src/Kestrel.Core/Adapter/Internal/RawStream.cs b/src/Kestrel.Core/Adapter/Internal/RawStream.cs index a2bd9a0078..fdd4470637 100644 --- a/src/Kestrel.Core/Adapter/Internal/RawStream.cs +++ b/src/Kestrel.Core/Adapter/Internal/RawStream.cs @@ -199,4 +199,4 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal return tcs.Task; } } -} \ No newline at end of file +} diff --git a/src/Kestrel.Core/Features/IHttp2StreamIdFeature.cs b/src/Kestrel.Core/Features/IHttp2StreamIdFeature.cs new file mode 100644 index 0000000000..30ad135062 --- /dev/null +++ b/src/Kestrel.Core/Features/IHttp2StreamIdFeature.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Features +{ + public interface IHttp2StreamIdFeature + { + int StreamId { get; } + } +} diff --git a/src/Kestrel.Core/Features/ITlsApplicationProtocolFeature.cs b/src/Kestrel.Core/Features/ITlsApplicationProtocolFeature.cs new file mode 100644 index 0000000000..7ad37730d5 --- /dev/null +++ b/src/Kestrel.Core/Features/ITlsApplicationProtocolFeature.cs @@ -0,0 +1,11 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Features +{ + // TODO: this should be merged with ITlsConnectionFeature + public interface ITlsApplicationProtocolFeature + { + string ApplicationProtocol { get; } + } +} diff --git a/src/Kestrel.Core/Internal/FrameConnection.cs b/src/Kestrel.Core/Internal/FrameConnection.cs index dbe08fce25..e1b2dbde37 100644 --- a/src/Kestrel.Core/Internal/FrameConnection.cs +++ b/src/Kestrel.Core/Internal/FrameConnection.cs @@ -13,7 +13,9 @@ using Microsoft.AspNetCore.Hosting.Server; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Protocols.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.Extensions.Logging; @@ -21,10 +23,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal { public class FrameConnection : IConnectionApplicationFeature, ITimeoutControl { + private const int Http2ConnectionNotStarted = 0; + private const int Http2ConnectionStarted = 1; + private const int Http2ConnectionClosed = 2; + private readonly FrameConnectionContext _context; private List _adaptedConnections; private readonly TaskCompletionSource _socketClosedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); private Frame _frame; + private Http2Connection _http2Connection; + private volatile int _http2ConnectionState; private long _lastTimestamp; private long _timeoutTimestamp = long.MaxValue; @@ -116,6 +124,20 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal // _frame must be initialized before adding the connection to the connection manager CreateFrame(application, input, output); + // _http2Connection must be initialized before yield control to the transport thread, + // to prevent a race condition where _http2Connection.Abort() is called just as + // _http2Connection is about to be initialized. + _http2Connection = new Http2Connection(new Http2ConnectionContext + { + ConnectionId = _context.ConnectionId, + ServiceContext = _context.ServiceContext, + PipeFactory = PipeFactory, + LocalEndPoint = LocalEndPoint, + RemoteEndPoint = RemoteEndPoint, + Input = input, + Output = output + }); + // Do this before the first await so we don't yield control to the transport until we've // added the connection to the connection manager _context.ServiceContext.ConnectionManager.AddConnection(_context.FrameConnectionId, this); @@ -128,7 +150,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal adaptedPipelineTask = adaptedPipeline.RunAsync(stream); } - await _frame.ProcessRequestsAsync(); + if (_frame.ConnectionFeatures?.Get()?.ApplicationProtocol == "h2" && + Interlocked.CompareExchange(ref _http2ConnectionState, Http2ConnectionStarted, Http2ConnectionNotStarted) == Http2ConnectionNotStarted) + { + await _http2Connection.ProcessAsync(application); + } + else + { + await _frame.ProcessRequestsAsync(); + } + await adaptedPipelineTask; await _socketClosedTcs.Task; } @@ -173,10 +204,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal public void OnConnectionClosed(Exception ex) { - Debug.Assert(_frame != null, $"{nameof(_frame)} is null"); - - // Abort the connection (if not already aborted) - _frame.Abort(ex); + Abort(ex); _socketClosedTcs.TrySetResult(null); } @@ -185,7 +213,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal { Debug.Assert(_frame != null, $"{nameof(_frame)} is null"); - _frame.Stop(); + if (Interlocked.Exchange(ref _http2ConnectionState, Http2ConnectionClosed) == Http2ConnectionStarted) + { + _http2Connection.Stop(); + } + else + { + _frame.Stop(); + } return _lifetimeTask; } @@ -195,15 +230,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal Debug.Assert(_frame != null, $"{nameof(_frame)} is null"); // Abort the connection (if not already aborted) - _frame.Abort(ex); + if (Interlocked.Exchange(ref _http2ConnectionState, Http2ConnectionClosed) == Http2ConnectionStarted) + { + _http2Connection.Abort(ex); + } + else + { + _frame.Abort(ex); + } } public Task AbortAsync(Exception ex) { - Debug.Assert(_frame != null, $"{nameof(_frame)} is null"); - - // Abort the connection (if not already aborted) - _frame.Abort(ex); + Abort(ex); return _lifetimeTask; } diff --git a/src/Kestrel.Core/Internal/Http/FrameRequestStream.cs b/src/Kestrel.Core/Internal/Http/FrameRequestStream.cs index 5adc0a6e71..e04966ba94 100644 --- a/src/Kestrel.Core/Internal/Http/FrameRequestStream.cs +++ b/src/Kestrel.Core/Internal/Http/FrameRequestStream.cs @@ -15,7 +15,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http internal class FrameRequestStream : ReadOnlyStream { private readonly IHttpBodyControlFeature _bodyControl; - private MessageBody _body; + private IMessageBody _body; private FrameStreamState _state; private Exception _error; @@ -159,7 +159,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } } - public void StartAcceptingReads(MessageBody body) + public void StartAcceptingReads(IMessageBody body) { // Only start if not aborted if (_state == FrameStreamState.Closed) diff --git a/src/Kestrel.Core/Internal/Http/IMessageBody.cs b/src/Kestrel.Core/Internal/Http/IMessageBody.cs new file mode 100644 index 0000000000..1f21fb47d1 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http/IMessageBody.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +{ + public interface IMessageBody + { + Task ReadAsync(ArraySegment buffer, CancellationToken cancellationToken = default(CancellationToken)); + + Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default(CancellationToken)); + } +} diff --git a/src/Kestrel.Core/Internal/Http/MessageBody.cs b/src/Kestrel.Core/Internal/Http/MessageBody.cs index 60e780375e..164f7f69ca 100644 --- a/src/Kestrel.Core/Internal/Http/MessageBody.cs +++ b/src/Kestrel.Core/Internal/Http/MessageBody.cs @@ -11,7 +11,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { - public abstract class MessageBody + public abstract class MessageBody : IMessageBody { private static readonly MessageBody _zeroContentLengthClose = new ForZeroContentLength(keepAlive: false); private static readonly MessageBody _zeroContentLengthKeepAlive = new ForZeroContentLength(keepAlive: true); diff --git a/src/Kestrel.Core/Internal/Http/RequestRejectionReason.cs b/src/Kestrel.Core/Internal/Http/RequestRejectionReason.cs index b482c87a02..0a7cb1eeb1 100644 --- a/src/Kestrel.Core/Internal/Http/RequestRejectionReason.cs +++ b/src/Kestrel.Core/Internal/Http/RequestRejectionReason.cs @@ -32,5 +32,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http MultipleHostHeaders, InvalidHostHeader, UpgradeRequestCannotHavePayload, + RequestBodyExceedsContentLength } } diff --git a/src/Kestrel.Core/Internal/Http2/HPack/DynamicTable.cs b/src/Kestrel.Core/Internal/Http2/HPack/DynamicTable.cs new file mode 100644 index 0000000000..0b58c69ba6 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/HPack/DynamicTable.cs @@ -0,0 +1,72 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class DynamicTable + { + private readonly HeaderField[] _buffer; + private int _maxSize = 4096; + private int _size; + private int _count; + private int _insertIndex; + private int _removeIndex; + + public DynamicTable(int maxSize) + { + _buffer = new HeaderField[maxSize]; + _maxSize = maxSize; + } + + public int Count => _count; + + public int Size => _size; + + public HeaderField this[int index] + { + get + { + if (index >= _count) + { + throw new IndexOutOfRangeException(); + } + + return _buffer[_insertIndex == 0 ? _buffer.Length - 1 : _insertIndex - index - 1]; + } + } + + public void Insert(string name, string value) + { + var entrySize = name.Length + value.Length + 32; + EnsureSize(_maxSize - entrySize); + + if (_maxSize < entrySize) + { + throw new InvalidOperationException($"Unable to add entry of size {entrySize} to dynamic table of size {_maxSize}."); + } + + _buffer[_insertIndex] = new HeaderField(name, value); + _insertIndex = (_insertIndex + 1) % _buffer.Length; + _size += entrySize; + _count++; + } + + public void Resize(int maxSize) + { + _maxSize = maxSize; + EnsureSize(_maxSize); + } + + public void EnsureSize(int size) + { + while (_count > 0 && _size > size) + { + _size -= _buffer[_removeIndex].Name.Length + _buffer[_removeIndex].Value.Length + 32; + _count--; + _removeIndex = (_removeIndex + 1) % _buffer.Length; + } + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/HPack/HPackDecoder.cs b/src/Kestrel.Core/Internal/Http2/HPack/HPackDecoder.cs new file mode 100644 index 0000000000..33a1e014ad --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/HPack/HPackDecoder.cs @@ -0,0 +1,277 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class HPackDecoder + { + private enum State + { + Ready, + HeaderFieldIndex, + HeaderNameIndex, + HeaderNameLength, + HeaderNameLengthContinue, + HeaderName, + HeaderValueLength, + HeaderValueLengthContinue, + HeaderValue, + DynamicTableSize + } + + private const byte IndexedHeaderFieldMask = 0x80; + private const byte LiteralHeaderFieldWithIncrementalIndexingMask = 0x40; + private const byte LiteralHeaderFieldWithoutIndexingMask = 0x00; + private const byte LiteralHeaderFieldNeverIndexedMask = 0x10; + private const byte DynamicTableSizeUpdateMask = 0x20; + private const byte HuffmanMask = 0x80; + + private const int IndexedHeaderFieldPrefix = 7; + private const int LiteralHeaderFieldWithIncrementalIndexingPrefix = 6; + private const int LiteralHeaderFieldWithoutIndexingPrefix = 4; + private const int LiteralHeaderFieldNeverIndexedPrefix = 4; + private const int DynamicTableSizeUpdatePrefix = 5; + private const int StringLengthPrefix = 7; + + private readonly DynamicTable _dynamicTable = new DynamicTable(4096); + private readonly IntegerDecoder _integerDecoder = new IntegerDecoder(); + + private State _state = State.Ready; + // TODO: add new HTTP/2 header size limit and allocate accordingly + private byte[] _stringOctets = new byte[Http2Frame.MinAllowedMaxFrameSize]; + private string _headerName = string.Empty; + private string _headerValue = string.Empty; + private int _stringLength; + private int _stringIndex; + private bool _index; + private bool _huffman; + + public void Decode(Span data, IHeaderDictionary headers) + { + for (var i = 0; i < data.Length; i++) + { + OnByte(data[i], headers); + } + } + + public void OnByte(byte b, IHeaderDictionary headers) + { + switch (_state) + { + case State.Ready: + if ((b & IndexedHeaderFieldMask) == IndexedHeaderFieldMask) + { + if (_integerDecoder.BeginDecode((byte)(b & ~IndexedHeaderFieldMask), IndexedHeaderFieldPrefix)) + { + OnIndexedHeaderField(_integerDecoder.Value, headers); + } + else + { + _state = State.HeaderFieldIndex; + } + } + else if ((b & LiteralHeaderFieldWithIncrementalIndexingMask) == LiteralHeaderFieldWithIncrementalIndexingMask) + { + _index = true; + var val = b & ~LiteralHeaderFieldWithIncrementalIndexingMask; + + if (val == 0) + { + _state = State.HeaderNameLength; + } + else if (_integerDecoder.BeginDecode((byte)val, LiteralHeaderFieldWithIncrementalIndexingPrefix)) + { + OnIndexedHeaderName(_integerDecoder.Value); + } + else + { + _state = State.HeaderNameIndex; + } + } + else if ((b & LiteralHeaderFieldWithoutIndexingMask) == LiteralHeaderFieldWithoutIndexingMask) + { + _index = false; + var val = b & ~LiteralHeaderFieldWithoutIndexingMask; + + if (val == 0) + { + _state = State.HeaderNameLength; + } + else if (_integerDecoder.BeginDecode((byte)val, LiteralHeaderFieldWithoutIndexingPrefix)) + { + OnIndexedHeaderName(_integerDecoder.Value); + } + else + { + _state = State.HeaderNameIndex; + } + } + else if ((b & LiteralHeaderFieldNeverIndexedMask) == LiteralHeaderFieldNeverIndexedMask) + { + _index = false; + var val = b & ~LiteralHeaderFieldNeverIndexedMask; + + if (val == 0) + { + _state = State.HeaderNameLength; + } + else if (_integerDecoder.BeginDecode((byte)val, LiteralHeaderFieldNeverIndexedPrefix)) + { + OnIndexedHeaderName(_integerDecoder.Value); + } + else + { + _state = State.HeaderNameIndex; + } + } + else if ((b & DynamicTableSizeUpdateMask) == DynamicTableSizeUpdateMask) + { + if (_integerDecoder.BeginDecode((byte)(b & ~DynamicTableSizeUpdateMask), DynamicTableSizeUpdatePrefix)) + { + // TODO: validate that it's less than what's defined via SETTINGS + _dynamicTable.Resize(_integerDecoder.Value); + } + else + { + _state = State.DynamicTableSize; + } + } + else + { + throw new InvalidOperationException(); + } + + break; + case State.HeaderFieldIndex: + if (_integerDecoder.Decode(b)) + { + OnIndexedHeaderField(_integerDecoder.Value, headers); + } + + break; + case State.HeaderNameIndex: + if (_integerDecoder.Decode(b)) + { + OnIndexedHeaderName(_integerDecoder.Value); + } + + break; + case State.HeaderNameLength: + _huffman = (b & HuffmanMask) == HuffmanMask; + + if (_integerDecoder.BeginDecode((byte)(b & ~HuffmanMask), StringLengthPrefix)) + { + OnStringLength(_integerDecoder.Value, nextState: State.HeaderName); + } + else + { + _state = State.HeaderNameLengthContinue; + } + + break; + case State.HeaderNameLengthContinue: + if (_integerDecoder.Decode(b)) + { + OnStringLength(_integerDecoder.Value, nextState: State.HeaderName); + } + + break; + case State.HeaderName: + _stringOctets[_stringIndex++] = b; + + if (_stringIndex == _stringLength) + { + _headerName = OnString(nextState: State.HeaderValueLength); + } + + break; + case State.HeaderValueLength: + _huffman = (b & HuffmanMask) == HuffmanMask; + + if (_integerDecoder.BeginDecode((byte)(b & ~HuffmanMask), StringLengthPrefix)) + { + OnStringLength(_integerDecoder.Value, nextState: State.HeaderValue); + } + else + { + _state = State.HeaderValueLengthContinue; + } + + break; + case State.HeaderValueLengthContinue: + if (_integerDecoder.Decode(b)) + { + OnStringLength(_integerDecoder.Value, nextState: State.HeaderValue); + } + + break; + case State.HeaderValue: + _stringOctets[_stringIndex++] = b; + + if (_stringIndex == _stringLength) + { + _headerValue = OnString(nextState: State.Ready); + headers.Append(_headerName, _headerValue); + + if (_index) + { + _dynamicTable.Insert(_headerName, _headerValue); + } + } + + break; + case State.DynamicTableSize: + if (_integerDecoder.Decode(b)) + { + // TODO: validate that it's less than what's defined via SETTINGS + _dynamicTable.Resize(_integerDecoder.Value); + _state = State.Ready; + } + + break; + default: + // Can't happen + throw new InvalidOperationException(); + } + } + + private void OnIndexedHeaderField(int index, IHeaderDictionary headers) + { + var header = GetHeader(index); + headers.Append(header.Name, header.Value); + _state = State.Ready; + } + + private void OnIndexedHeaderName(int index) + { + var header = GetHeader(index); + _headerName = header.Name; + _state = State.HeaderValueLength; + } + + private void OnStringLength(int length, State nextState) + { + _stringLength = length; + _stringIndex = 0; + _state = nextState; + } + + private string OnString(State nextState) + { + _state = nextState; + return _huffman + ? Huffman.Decode(_stringOctets, 0, _stringLength) + : Encoding.ASCII.GetString(_stringOctets, 0, _stringLength); + } + + private HeaderField GetHeader(int index) => index <= StaticTable.Instance.Length + ? StaticTable.Instance[index - 1] + : _dynamicTable[index - StaticTable.Instance.Length - 1]; + } +} diff --git a/src/Kestrel.Core/Internal/Http2/HPack/HPackEncoder.cs b/src/Kestrel.Core/Internal/Http2/HPack/HPackEncoder.cs new file mode 100644 index 0000000000..25206fed5b --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/HPack/HPackEncoder.cs @@ -0,0 +1,154 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class HPackEncoder + { + private IEnumerator> _enumerator; + + public bool BeginEncode(IEnumerable> headers, Span buffer, out int length) + { + _enumerator = headers.GetEnumerator(); + _enumerator.MoveNext(); + + return Encode(buffer, out length); + } + + public bool BeginEncode(int statusCode, IEnumerable> headers, Span buffer, out int length) + { + _enumerator = headers.GetEnumerator(); + _enumerator.MoveNext(); + + var statusCodeLength = EncodeStatusCode(statusCode, buffer); + var done = Encode(buffer.Slice(statusCodeLength), out var headersLength); + length = statusCodeLength + headersLength; + + return done; + } + + public bool Encode(Span buffer, out int length) + { + length = 0; + + do + { + if (!EncodeHeader(_enumerator.Current.Key, _enumerator.Current.Value, buffer.Slice(length), out var headerLength)) + { + return false; + } + + length += headerLength; + } while (_enumerator.MoveNext()); + + return true; + } + + private int EncodeStatusCode(int statusCode, Span buffer) + { + switch (statusCode) + { + case 200: + case 204: + case 206: + case 304: + case 400: + case 404: + case 500: + buffer[0] = (byte)(0x80 | StaticTable.Instance.StatusIndex[statusCode]); + return 1; + default: + // Send as Literal Header Field Without Indexing - Indexed Name + buffer[0] = 0x08; + + var statusBytes = StatusCodes.ToStatusBytes(statusCode); + buffer[1] = (byte)statusBytes.Length; + ((Span)statusBytes).CopyTo(buffer.Slice(2)); + + return 2 + statusBytes.Length; + } + } + + private bool EncodeHeader(string name, string value, Span buffer, out int length) + { + var i = 0; + length = 0; + + if (buffer.Length == 0) + { + return false; + } + + buffer[i++] = 0; + + if (i == buffer.Length) + { + return false; + } + + if (!EncodeString(name, buffer.Slice(i), out var nameLength, lowercase: true)) + { + return false; + } + + i += nameLength; + + if (i >= buffer.Length) + { + return false; + } + + if (!EncodeString(value, buffer.Slice(i), out var valueLength, lowercase: false)) + { + return false; + } + + i += valueLength; + + length = i; + return true; + } + + private bool EncodeString(string s, Span buffer, out int length, bool lowercase) + { + const int toLowerMask = 0x20; + + var i = 0; + length = 0; + + if (buffer.Length == 0) + { + return false; + } + + buffer[0] = 0; + + if (!IntegerEncoder.Encode(s.Length, 7, buffer, out var nameLength)) + { + return false; + } + + i += nameLength; + + // TODO: use huffman encoding + for (var j = 0; j < s.Length; j++) + { + if (i >= buffer.Length) + { + return false; + } + + buffer[i++] = (byte)(s[j] | (lowercase ? toLowerMask : 0)); + } + + length = i; + return true; + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/HPack/HeaderField.cs b/src/Kestrel.Core/Internal/Http2/HPack/HeaderField.cs new file mode 100644 index 0000000000..9c3872cad2 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/HPack/HeaderField.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public struct HeaderField + { + public HeaderField(string name, string value) + { + Name = name; + Value = value; + } + + public string Name { get; } + public string Value { get; } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/HPack/Huffman.cs b/src/Kestrel.Core/Internal/Http2/HPack/Huffman.cs new file mode 100644 index 0000000000..7c9e52f446 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/HPack/Huffman.cs @@ -0,0 +1,363 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class Huffman + { + // TODO: this can be constructed from _decodingTable + private static readonly (uint code, int bitLength)[] _encodingTable = new (uint code, int bitLength)[] + { + (0b11111111_11000000_00000000_00000000, 13), + (0b11111111_11111111_10110000_00000000, 23), + (0b11111111_11111111_11111110_00100000, 28), + (0b11111111_11111111_11111110_00110000, 28), + (0b11111111_11111111_11111110_01000000, 28), + (0b11111111_11111111_11111110_01010000, 28), + (0b11111111_11111111_11111110_01100000, 28), + (0b11111111_11111111_11111110_01110000, 28), + (0b11111111_11111111_11111110_10000000, 28), + (0b11111111_11111111_11101010_00000000, 24), + (0b11111111_11111111_11111111_11110000, 30), + (0b11111111_11111111_11111110_10010000, 28), + (0b11111111_11111111_11111110_10100000, 28), + (0b11111111_11111111_11111111_11110100, 30), + (0b11111111_11111111_11111110_10110000, 28), + (0b11111111_11111111_11111110_11000000, 28), + (0b11111111_11111111_11111110_11010000, 28), + (0b11111111_11111111_11111110_11100000, 28), + (0b11111111_11111111_11111110_11110000, 28), + (0b11111111_11111111_11111111_00000000, 28), + (0b11111111_11111111_11111111_00010000, 28), + (0b11111111_11111111_11111111_00100000, 28), + (0b11111111_11111111_11111111_11111000, 30), + (0b11111111_11111111_11111111_00110000, 28), + (0b11111111_11111111_11111111_01000000, 28), + (0b11111111_11111111_11111111_01010000, 28), + (0b11111111_11111111_11111111_01100000, 28), + (0b11111111_11111111_11111111_01110000, 28), + (0b11111111_11111111_11111111_10000000, 28), + (0b11111111_11111111_11111111_10010000, 28), + (0b11111111_11111111_11111111_10100000, 28), + (0b11111111_11111111_11111111_10110000, 28), + (0b01010000_00000000_00000000_00000000, 6), + (0b11111110_00000000_00000000_00000000, 10), + (0b11111110_01000000_00000000_00000000, 10), + (0b11111111_10100000_00000000_00000000, 12), + (0b11111111_11001000_00000000_00000000, 13), + (0b01010100_00000000_00000000_00000000, 6), + (0b11111000_00000000_00000000_00000000, 8), + (0b11111111_01000000_00000000_00000000, 11), + (0b11111110_10000000_00000000_00000000, 10), + (0b11111110_11000000_00000000_00000000, 10), + (0b11111001_00000000_00000000_00000000, 8), + (0b11111111_01100000_00000000_00000000, 11), + (0b11111010_00000000_00000000_00000000, 8), + (0b01011000_00000000_00000000_00000000, 6), + (0b01011100_00000000_00000000_00000000, 6), + (0b01100000_00000000_00000000_00000000, 6), + (0b00000000_00000000_00000000_00000000, 5), + (0b00001000_00000000_00000000_00000000, 5), + (0b00010000_00000000_00000000_00000000, 5), + (0b01100100_00000000_00000000_00000000, 6), + (0b01101000_00000000_00000000_00000000, 6), + (0b01101100_00000000_00000000_00000000, 6), + (0b01110000_00000000_00000000_00000000, 6), + (0b01110100_00000000_00000000_00000000, 6), + (0b01111000_00000000_00000000_00000000, 6), + (0b01111100_00000000_00000000_00000000, 6), + (0b10111000_00000000_00000000_00000000, 7), + (0b11111011_00000000_00000000_00000000, 8), + (0b11111111_11111000_00000000_00000000, 15), + (0b10000000_00000000_00000000_00000000, 6), + (0b11111111_10110000_00000000_00000000, 12), + (0b11111111_00000000_00000000_00000000, 10), + (0b11111111_11010000_00000000_00000000, 13), + (0b10000100_00000000_00000000_00000000, 6), + (0b10111010_00000000_00000000_00000000, 7), + (0b10111100_00000000_00000000_00000000, 7), + (0b10111110_00000000_00000000_00000000, 7), + (0b11000000_00000000_00000000_00000000, 7), + (0b11000010_00000000_00000000_00000000, 7), + (0b11000100_00000000_00000000_00000000, 7), + (0b11000110_00000000_00000000_00000000, 7), + (0b11001000_00000000_00000000_00000000, 7), + (0b11001010_00000000_00000000_00000000, 7), + (0b11001100_00000000_00000000_00000000, 7), + (0b11001110_00000000_00000000_00000000, 7), + (0b11010000_00000000_00000000_00000000, 7), + (0b11010010_00000000_00000000_00000000, 7), + (0b11010100_00000000_00000000_00000000, 7), + (0b11010110_00000000_00000000_00000000, 7), + (0b11011000_00000000_00000000_00000000, 7), + (0b11011010_00000000_00000000_00000000, 7), + (0b11011100_00000000_00000000_00000000, 7), + (0b11011110_00000000_00000000_00000000, 7), + (0b11100000_00000000_00000000_00000000, 7), + (0b11100010_00000000_00000000_00000000, 7), + (0b11100100_00000000_00000000_00000000, 7), + (0b11111100_00000000_00000000_00000000, 8), + (0b11100110_00000000_00000000_00000000, 7), + (0b11111101_00000000_00000000_00000000, 8), + (0b11111111_11011000_00000000_00000000, 13), + (0b11111111_11111110_00000000_00000000, 19), + (0b11111111_11100000_00000000_00000000, 13), + (0b11111111_11110000_00000000_00000000, 14), + (0b10001000_00000000_00000000_00000000, 6), + (0b11111111_11111010_00000000_00000000, 15), + (0b00011000_00000000_00000000_00000000, 5), + (0b10001100_00000000_00000000_00000000, 6), + (0b00100000_00000000_00000000_00000000, 5), + (0b10010000_00000000_00000000_00000000, 6), + (0b00101000_00000000_00000000_00000000, 5), + (0b10010100_00000000_00000000_00000000, 6), + (0b10011000_00000000_00000000_00000000, 6), + (0b10011100_00000000_00000000_00000000, 6), + (0b00110000_00000000_00000000_00000000, 5), + (0b11101000_00000000_00000000_00000000, 7), + (0b11101010_00000000_00000000_00000000, 7), + (0b10100000_00000000_00000000_00000000, 6), + (0b10100100_00000000_00000000_00000000, 6), + (0b10101000_00000000_00000000_00000000, 6), + (0b00111000_00000000_00000000_00000000, 5), + (0b10101100_00000000_00000000_00000000, 6), + (0b11101100_00000000_00000000_00000000, 7), + (0b10110000_00000000_00000000_00000000, 6), + (0b01000000_00000000_00000000_00000000, 5), + (0b01001000_00000000_00000000_00000000, 5), + (0b10110100_00000000_00000000_00000000, 6), + (0b11101110_00000000_00000000_00000000, 7), + (0b11110000_00000000_00000000_00000000, 7), + (0b11110010_00000000_00000000_00000000, 7), + (0b11110100_00000000_00000000_00000000, 7), + (0b11110110_00000000_00000000_00000000, 7), + (0b11111111_11111100_00000000_00000000, 15), + (0b11111111_10000000_00000000_00000000, 11), + (0b11111111_11110100_00000000_00000000, 14), + (0b11111111_11101000_00000000_00000000, 13), + (0b11111111_11111111_11111111_11000000, 28), + (0b11111111_11111110_01100000_00000000, 20), + (0b11111111_11111111_01001000_00000000, 22), + (0b11111111_11111110_01110000_00000000, 20), + (0b11111111_11111110_10000000_00000000, 20), + (0b11111111_11111111_01001100_00000000, 22), + (0b11111111_11111111_01010000_00000000, 22), + (0b11111111_11111111_01010100_00000000, 22), + (0b11111111_11111111_10110010_00000000, 23), + (0b11111111_11111111_01011000_00000000, 22), + (0b11111111_11111111_10110100_00000000, 23), + (0b11111111_11111111_10110110_00000000, 23), + (0b11111111_11111111_10111000_00000000, 23), + (0b11111111_11111111_10111010_00000000, 23), + (0b11111111_11111111_10111100_00000000, 23), + (0b11111111_11111111_11101011_00000000, 24), + (0b11111111_11111111_10111110_00000000, 23), + (0b11111111_11111111_11101100_00000000, 24), + (0b11111111_11111111_11101101_00000000, 24), + (0b11111111_11111111_01011100_00000000, 22), + (0b11111111_11111111_11000000_00000000, 23), + (0b11111111_11111111_11101110_00000000, 24), + (0b11111111_11111111_11000010_00000000, 23), + (0b11111111_11111111_11000100_00000000, 23), + (0b11111111_11111111_11000110_00000000, 23), + (0b11111111_11111111_11001000_00000000, 23), + (0b11111111_11111110_11100000_00000000, 21), + (0b11111111_11111111_01100000_00000000, 22), + (0b11111111_11111111_11001010_00000000, 23), + (0b11111111_11111111_01100100_00000000, 22), + (0b11111111_11111111_11001100_00000000, 23), + (0b11111111_11111111_11001110_00000000, 23), + (0b11111111_11111111_11101111_00000000, 24), + (0b11111111_11111111_01101000_00000000, 22), + (0b11111111_11111110_11101000_00000000, 21), + (0b11111111_11111110_10010000_00000000, 20), + (0b11111111_11111111_01101100_00000000, 22), + (0b11111111_11111111_01110000_00000000, 22), + (0b11111111_11111111_11010000_00000000, 23), + (0b11111111_11111111_11010010_00000000, 23), + (0b11111111_11111110_11110000_00000000, 21), + (0b11111111_11111111_11010100_00000000, 23), + (0b11111111_11111111_01110100_00000000, 22), + (0b11111111_11111111_01111000_00000000, 22), + (0b11111111_11111111_11110000_00000000, 24), + (0b11111111_11111110_11111000_00000000, 21), + (0b11111111_11111111_01111100_00000000, 22), + (0b11111111_11111111_11010110_00000000, 23), + (0b11111111_11111111_11011000_00000000, 23), + (0b11111111_11111111_00000000_00000000, 21), + (0b11111111_11111111_00001000_00000000, 21), + (0b11111111_11111111_10000000_00000000, 22), + (0b11111111_11111111_00010000_00000000, 21), + (0b11111111_11111111_11011010_00000000, 23), + (0b11111111_11111111_10000100_00000000, 22), + (0b11111111_11111111_11011100_00000000, 23), + (0b11111111_11111111_11011110_00000000, 23), + (0b11111111_11111110_10100000_00000000, 20), + (0b11111111_11111111_10001000_00000000, 22), + (0b11111111_11111111_10001100_00000000, 22), + (0b11111111_11111111_10010000_00000000, 22), + (0b11111111_11111111_11100000_00000000, 23), + (0b11111111_11111111_10010100_00000000, 22), + (0b11111111_11111111_10011000_00000000, 22), + (0b11111111_11111111_11100010_00000000, 23), + (0b11111111_11111111_11111000_00000000, 26), + (0b11111111_11111111_11111000_01000000, 26), + (0b11111111_11111110_10110000_00000000, 20), + (0b11111111_11111110_00100000_00000000, 19), + (0b11111111_11111111_10011100_00000000, 22), + (0b11111111_11111111_11100100_00000000, 23), + (0b11111111_11111111_10100000_00000000, 22), + (0b11111111_11111111_11110110_00000000, 25), + (0b11111111_11111111_11111000_10000000, 26), + (0b11111111_11111111_11111000_11000000, 26), + (0b11111111_11111111_11111001_00000000, 26), + (0b11111111_11111111_11111011_11000000, 27), + (0b11111111_11111111_11111011_11100000, 27), + (0b11111111_11111111_11111001_01000000, 26), + (0b11111111_11111111_11110001_00000000, 24), + (0b11111111_11111111_11110110_10000000, 25), + (0b11111111_11111110_01000000_00000000, 19), + (0b11111111_11111111_00011000_00000000, 21), + (0b11111111_11111111_11111001_10000000, 26), + (0b11111111_11111111_11111100_00000000, 27), + (0b11111111_11111111_11111100_00100000, 27), + (0b11111111_11111111_11111001_11000000, 26), + (0b11111111_11111111_11111100_01000000, 27), + (0b11111111_11111111_11110010_00000000, 24), + (0b11111111_11111111_00100000_00000000, 21), + (0b11111111_11111111_00101000_00000000, 21), + (0b11111111_11111111_11111010_00000000, 26), + (0b11111111_11111111_11111010_01000000, 26), + (0b11111111_11111111_11111111_11010000, 28), + (0b11111111_11111111_11111100_01100000, 27), + (0b11111111_11111111_11111100_10000000, 27), + (0b11111111_11111111_11111100_10100000, 27), + (0b11111111_11111110_11000000_00000000, 20), + (0b11111111_11111111_11110011_00000000, 24), + (0b11111111_11111110_11010000_00000000, 20), + (0b11111111_11111111_00110000_00000000, 21), + (0b11111111_11111111_10100100_00000000, 22), + (0b11111111_11111111_00111000_00000000, 21), + (0b11111111_11111111_01000000_00000000, 21), + (0b11111111_11111111_11100110_00000000, 23), + (0b11111111_11111111_10101000_00000000, 22), + (0b11111111_11111111_10101100_00000000, 22), + (0b11111111_11111111_11110111_00000000, 25), + (0b11111111_11111111_11110111_10000000, 25), + (0b11111111_11111111_11110100_00000000, 24), + (0b11111111_11111111_11110101_00000000, 24), + (0b11111111_11111111_11111010_10000000, 26), + (0b11111111_11111111_11101000_00000000, 23), + (0b11111111_11111111_11111010_11000000, 26), + (0b11111111_11111111_11111100_11000000, 27), + (0b11111111_11111111_11111011_00000000, 26), + (0b11111111_11111111_11111011_01000000, 26), + (0b11111111_11111111_11111100_11100000, 27), + (0b11111111_11111111_11111101_00000000, 27), + (0b11111111_11111111_11111101_00100000, 27), + (0b11111111_11111111_11111101_01000000, 27), + (0b11111111_11111111_11111101_01100000, 27), + (0b11111111_11111111_11111111_11100000, 28), + (0b11111111_11111111_11111101_10000000, 27), + (0b11111111_11111111_11111101_10100000, 27), + (0b11111111_11111111_11111101_11000000, 27), + (0b11111111_11111111_11111101_11100000, 27), + (0b11111111_11111111_11111110_00000000, 27), + (0b11111111_11111111_11111011_10000000, 26), + (0b11111111_11111111_11111111_11111100, 30) + }; + + private static readonly (int codeLength, int[] codes)[] _decodingTable = new[] + { + (5, new[] { 48, 49, 50, 97, 99, 101, 105, 111, 115, 116 }), + (6, new[] { 32, 37, 45, 46, 47, 51, 52, 53, 54, 55, 56, 57, 61, 65, 95, 98, 100, 102, 103, 104, 108, 109, 110, 112, 114, 117 }), + (7, new[] { 58, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 106, 107, 113, 118, 119, 120, 121, 122 }), + (8, new[] { 38, 42, 44, 59, 88, 90 }), + (10, new[] { 33, 34, 40, 41, 63 }), + (11, new[] { 39, 43, 124 }), + (12, new[] { 35, 62 }), + (13, new[] { 0, 36, 64, 91, 93, 126 }), + (14, new[] { 94, 125 }), + (15, new[] { 60, 96, 123 }), + (19, new[] { 92, 195, 208 }), + (20, new[] { 128, 130, 131, 162, 184, 194, 224, 226 }), + (21, new[] { 153, 161, 167, 172, 176, 177, 179, 209, 216, 217, 227, 229, 230 }), + (22, new[] { 129, 132, 133, 134, 136, 146, 154, 156, 160, 163, 164, 169, 170, 173, 178, 181, 185, 186, 187, 189, 190, 196, 198, 228, 232, 233 }), + (23, new[] { 1, 135, 137, 138, 139, 140, 141, 143, 147, 149, 150, 151, 152, 155, 157, 158, 165, 166, 168, 174, 175, 180, 182, 183, 188, 191, 197, 231, 239 }), + (24, new[] { 9, 142, 144, 145, 148, 159, 171, 206, 215, 225, 236, 237 }), + (25, new[] { 199, 207, 234, 235 }), + (26, new[] { 192, 193, 200, 201, 202, 205, 210, 213, 218, 219, 238, 240, 242, 243, 255 }), + (27, new[] { 203, 204, 211, 212, 214, 221, 222, 223, 241, 244, 245, 246, 247, 248, 250, 251, 252, 253, 254 }), + (28, new[] { 2, 3, 4, 5, 6, 7, 8, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 27, 28, 29, 30, 31, 127, 220, 249 }), + (30, new[] { 10, 13, 22, 256 }) + }; + + public static (uint encoded, int bitLength) Encode(int data) + { + return _encodingTable[data]; + } + + public static string Decode(byte[] data, int offset, int count) + { + var sb = new StringBuilder(); + + var i = offset; + var lastDecodedBits = 0; + while (i < count) + { + var next = (uint)(data[i] << 24 + lastDecodedBits); + next |= (i + 1 < data.Length ? (uint)(data[i + 1] << 16 + lastDecodedBits) : 0); + next |= (i + 2 < data.Length ? (uint)(data[i + 2] << 8 + lastDecodedBits) : 0); + next |= (i + 3 < data.Length ? (uint)(data[i + 3] << lastDecodedBits) : 0); + + var ones = (uint)(int.MinValue >> (8 - lastDecodedBits - 1)); + if (i == count - 1 && (next & ones) == ones) + { + // Padding + break; + } + + var ch = Decode(next, out var decodedBits); + sb.Append((char)ch); + + lastDecodedBits += decodedBits; + i += lastDecodedBits / 8; + lastDecodedBits %= 8; + } + + return sb.ToString(); + } + + public static int Decode(uint data, out int decodedBits) + { + var codeMax = 0; + + for (var i = 0; i < _decodingTable.Length; i++) + { + var (codeLength, codes) = _decodingTable[i]; + var mask = int.MinValue >> (codeLength - 1); + + if (i > 0) + { + codeMax <<= codeLength - _decodingTable[i - 1].codeLength; + } + + codeMax += codes.Length; + + var masked = (data & mask) >> (32 - codeLength); + + if (masked < codeMax) + { + decodedBits = codeLength; + return codes[codes.Length - (codeMax - masked)]; + } + } + + throw new Exception(); + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/HPack/IntegerDecoder.cs b/src/Kestrel.Core/Internal/Http2/HPack/IntegerDecoder.cs new file mode 100644 index 0000000000..c3bac3eb09 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/HPack/IntegerDecoder.cs @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class IntegerDecoder + { + private int _i; + private int _m; + + public int Value { get; private set; } + + public bool BeginDecode(byte b, int prefixLength) + { + if (b < ((1 << prefixLength) - 1)) + { + Value = b; + return true; + } + else + { + _i = b; + _m = 0; + return false; + } + } + + public bool Decode(byte b) + { + _i = _i + (b & 127) * (1 << _m); + _m = _m + 7; + + if ((b & 128) != 128) + { + Value = _i; + return true; + } + + return false; + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/HPack/IntegerEncoder.cs b/src/Kestrel.Core/Internal/Http2/HPack/IntegerEncoder.cs new file mode 100644 index 0000000000..6385459d14 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/HPack/IntegerEncoder.cs @@ -0,0 +1,59 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public static class IntegerEncoder + { + public static bool Encode(int i, int n, Span buffer, out int length) + { + var j = 0; + length = 0; + + if (buffer.Length == 0) + { + return false; + } + + if (i < (1 << n) - 1) + { + buffer[j] &= MaskHigh(8 - n); + buffer[j++] |= (byte)i; + } + else + { + buffer[j] &= MaskHigh(8 - n); + buffer[j++] |= (byte)((1 << n) - 1); + + if (j == buffer.Length) + { + return false; + } + + i = i - ((1 << n) - 1); + while (i >= 128) + { + buffer[j++] = (byte)(i % 128 + 128); + + if (j > buffer.Length) + { + return false; + } + + i = i / 128; + } + buffer[j++] = (byte)i; + } + + length = j; + return true; + } + + private static byte MaskHigh(int n) + { + return (byte)(sbyte.MinValue >> (n - 1)); + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/HPack/StaticTable.cs b/src/Kestrel.Core/Internal/Http2/HPack/StaticTable.cs new file mode 100644 index 0000000000..06612cc778 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/HPack/StaticTable.cs @@ -0,0 +1,101 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public class StaticTable + { + private static readonly StaticTable _instance = new StaticTable(); + + private readonly Dictionary _statusIndex = new Dictionary + { + [200] = 8, + [204] = 9, + [206] = 10, + [304] = 11, + [400] = 12, + [404] = 13, + [500] = 14, + }; + + private StaticTable() + { + } + + public static StaticTable Instance => _instance; + + public int Length => _staticTable.Length; + + public HeaderField this[int index] => _staticTable[index]; + + public IReadOnlyDictionary StatusIndex => _statusIndex; + + private readonly HeaderField[] _staticTable = new HeaderField[] + { + new HeaderField(":authority", ""), + new HeaderField(":method", "GET"), + new HeaderField(":method", "POST"), + new HeaderField(":path", "/"), + new HeaderField(":path", "/index.html"), + new HeaderField(":scheme", "http"), + new HeaderField(":scheme", "https"), + new HeaderField(":status", "200"), + new HeaderField(":status", "204"), + new HeaderField(":status", "206"), + new HeaderField(":status", "304"), + new HeaderField(":status", "400"), + new HeaderField(":status", "404"), + new HeaderField(":status", "500"), + new HeaderField("accept-charset", ""), + new HeaderField("accept-encoding", "gzip, deflate"), + new HeaderField("accept-language", ""), + new HeaderField("accept-ranges", ""), + new HeaderField("accept", ""), + new HeaderField("access-control-allow-origin", ""), + new HeaderField("age", ""), + new HeaderField("allow", ""), + new HeaderField("authorization", ""), + new HeaderField("cache-control", ""), + new HeaderField("content-disposition", ""), + new HeaderField("content-encoding", ""), + new HeaderField("content-language", ""), + new HeaderField("content-length", ""), + new HeaderField("content-location", ""), + new HeaderField("content-range", ""), + new HeaderField("content-type", ""), + new HeaderField("cookie", ""), + new HeaderField("date", ""), + new HeaderField("etag", ""), + new HeaderField("expect", ""), + new HeaderField("expires", ""), + new HeaderField("from", ""), + new HeaderField("host", ""), + new HeaderField("if-match", ""), + new HeaderField("if-modified-since", ""), + new HeaderField("if-none-match", ""), + new HeaderField("if-range", ""), + new HeaderField("if-unmodifiedsince", ""), + new HeaderField("last-modified", ""), + new HeaderField("link", ""), + new HeaderField("location", ""), + new HeaderField("max-forwards", ""), + new HeaderField("proxy-authenticate", ""), + new HeaderField("proxy-authorization", ""), + new HeaderField("range", ""), + new HeaderField("referer", ""), + new HeaderField("refresh", ""), + new HeaderField("retry-after", ""), + new HeaderField("server", ""), + new HeaderField("set-cookie", ""), + new HeaderField("strict-transport-security", ""), + new HeaderField("transfer-encoding", ""), + new HeaderField("user-agent", ""), + new HeaderField("vary", ""), + new HeaderField("via", ""), + new HeaderField("www-authenticate", "") + }; + } +} diff --git a/src/Kestrel.Core/Internal/Http2/HPack/StatusCodes.cs b/src/Kestrel.Core/Internal/Http2/HPack/StatusCodes.cs new file mode 100644 index 0000000000..056d5a8a1a --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/HPack/StatusCodes.cs @@ -0,0 +1,222 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Globalization; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack +{ + public static class StatusCodes + { + private static readonly byte[] _bytesStatus100 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status100Continue); + private static readonly byte[] _bytesStatus101 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status101SwitchingProtocols); + private static readonly byte[] _bytesStatus102 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status102Processing); + + private static readonly byte[] _bytesStatus200 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status200OK); + private static readonly byte[] _bytesStatus201 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status201Created); + private static readonly byte[] _bytesStatus202 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status202Accepted); + private static readonly byte[] _bytesStatus203 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status203NonAuthoritative); + private static readonly byte[] _bytesStatus204 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status204NoContent); + private static readonly byte[] _bytesStatus205 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status205ResetContent); + private static readonly byte[] _bytesStatus206 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status206PartialContent); + private static readonly byte[] _bytesStatus207 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status207MultiStatus); + private static readonly byte[] _bytesStatus208 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status208AlreadyReported); + private static readonly byte[] _bytesStatus226 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status226IMUsed); + + private static readonly byte[] _bytesStatus300 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status300MultipleChoices); + private static readonly byte[] _bytesStatus301 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status301MovedPermanently); + private static readonly byte[] _bytesStatus302 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status302Found); + private static readonly byte[] _bytesStatus303 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status303SeeOther); + private static readonly byte[] _bytesStatus304 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status304NotModified); + private static readonly byte[] _bytesStatus305 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status305UseProxy); + private static readonly byte[] _bytesStatus306 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status306SwitchProxy); + private static readonly byte[] _bytesStatus307 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status307TemporaryRedirect); + private static readonly byte[] _bytesStatus308 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status308PermanentRedirect); + + private static readonly byte[] _bytesStatus400 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status400BadRequest); + private static readonly byte[] _bytesStatus401 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status401Unauthorized); + private static readonly byte[] _bytesStatus402 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status402PaymentRequired); + private static readonly byte[] _bytesStatus403 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status403Forbidden); + private static readonly byte[] _bytesStatus404 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status404NotFound); + private static readonly byte[] _bytesStatus405 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status405MethodNotAllowed); + private static readonly byte[] _bytesStatus406 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status406NotAcceptable); + private static readonly byte[] _bytesStatus407 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status407ProxyAuthenticationRequired); + private static readonly byte[] _bytesStatus408 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status408RequestTimeout); + private static readonly byte[] _bytesStatus409 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status409Conflict); + private static readonly byte[] _bytesStatus410 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status410Gone); + private static readonly byte[] _bytesStatus411 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status411LengthRequired); + private static readonly byte[] _bytesStatus412 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status412PreconditionFailed); + private static readonly byte[] _bytesStatus413 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status413PayloadTooLarge); + private static readonly byte[] _bytesStatus414 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status414UriTooLong); + private static readonly byte[] _bytesStatus415 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status415UnsupportedMediaType); + private static readonly byte[] _bytesStatus416 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status416RangeNotSatisfiable); + private static readonly byte[] _bytesStatus417 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status417ExpectationFailed); + private static readonly byte[] _bytesStatus418 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status418ImATeapot); + private static readonly byte[] _bytesStatus419 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status419AuthenticationTimeout); + private static readonly byte[] _bytesStatus421 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status421MisdirectedRequest); + private static readonly byte[] _bytesStatus422 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status422UnprocessableEntity); + private static readonly byte[] _bytesStatus423 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status423Locked); + private static readonly byte[] _bytesStatus424 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status424FailedDependency); + private static readonly byte[] _bytesStatus426 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status426UpgradeRequired); + private static readonly byte[] _bytesStatus428 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status428PreconditionRequired); + private static readonly byte[] _bytesStatus429 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status429TooManyRequests); + private static readonly byte[] _bytesStatus431 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status431RequestHeaderFieldsTooLarge); + private static readonly byte[] _bytesStatus451 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status451UnavailableForLegalReasons); + + private static readonly byte[] _bytesStatus500 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status500InternalServerError); + private static readonly byte[] _bytesStatus501 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status501NotImplemented); + private static readonly byte[] _bytesStatus502 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status502BadGateway); + private static readonly byte[] _bytesStatus503 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status503ServiceUnavailable); + private static readonly byte[] _bytesStatus504 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status504GatewayTimeout); + private static readonly byte[] _bytesStatus505 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status505HttpVersionNotsupported); + private static readonly byte[] _bytesStatus506 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status506VariantAlsoNegotiates); + private static readonly byte[] _bytesStatus507 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status507InsufficientStorage); + private static readonly byte[] _bytesStatus508 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status508LoopDetected); + private static readonly byte[] _bytesStatus510 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status510NotExtended); + private static readonly byte[] _bytesStatus511 = CreateStatusBytes(Microsoft.AspNetCore.Http.StatusCodes.Status511NetworkAuthenticationRequired); + + private static byte[] CreateStatusBytes(int statusCode) + { + return Encoding.ASCII.GetBytes(statusCode.ToString(CultureInfo.InvariantCulture)); + } + + public static byte[] ToStatusBytes(int statusCode) + { + switch (statusCode) + { + case Microsoft.AspNetCore.Http.StatusCodes.Status100Continue: + return _bytesStatus100; + case Microsoft.AspNetCore.Http.StatusCodes.Status101SwitchingProtocols: + return _bytesStatus101; + case Microsoft.AspNetCore.Http.StatusCodes.Status102Processing: + return _bytesStatus102; + + case Microsoft.AspNetCore.Http.StatusCodes.Status200OK: + return _bytesStatus200; + case Microsoft.AspNetCore.Http.StatusCodes.Status201Created: + return _bytesStatus201; + case Microsoft.AspNetCore.Http.StatusCodes.Status202Accepted: + return _bytesStatus202; + case Microsoft.AspNetCore.Http.StatusCodes.Status203NonAuthoritative: + return _bytesStatus203; + case Microsoft.AspNetCore.Http.StatusCodes.Status204NoContent: + return _bytesStatus204; + case Microsoft.AspNetCore.Http.StatusCodes.Status205ResetContent: + return _bytesStatus205; + case Microsoft.AspNetCore.Http.StatusCodes.Status206PartialContent: + return _bytesStatus206; + case Microsoft.AspNetCore.Http.StatusCodes.Status207MultiStatus: + return _bytesStatus207; + case Microsoft.AspNetCore.Http.StatusCodes.Status208AlreadyReported: + return _bytesStatus208; + case Microsoft.AspNetCore.Http.StatusCodes.Status226IMUsed: + return _bytesStatus226; + + case Microsoft.AspNetCore.Http.StatusCodes.Status300MultipleChoices: + return _bytesStatus300; + case Microsoft.AspNetCore.Http.StatusCodes.Status301MovedPermanently: + return _bytesStatus301; + case Microsoft.AspNetCore.Http.StatusCodes.Status302Found: + return _bytesStatus302; + case Microsoft.AspNetCore.Http.StatusCodes.Status303SeeOther: + return _bytesStatus303; + case Microsoft.AspNetCore.Http.StatusCodes.Status304NotModified: + return _bytesStatus304; + case Microsoft.AspNetCore.Http.StatusCodes.Status305UseProxy: + return _bytesStatus305; + case Microsoft.AspNetCore.Http.StatusCodes.Status306SwitchProxy: + return _bytesStatus306; + case Microsoft.AspNetCore.Http.StatusCodes.Status307TemporaryRedirect: + return _bytesStatus307; + case Microsoft.AspNetCore.Http.StatusCodes.Status308PermanentRedirect: + return _bytesStatus308; + + case Microsoft.AspNetCore.Http.StatusCodes.Status400BadRequest: + return _bytesStatus400; + case Microsoft.AspNetCore.Http.StatusCodes.Status401Unauthorized: + return _bytesStatus401; + case Microsoft.AspNetCore.Http.StatusCodes.Status402PaymentRequired: + return _bytesStatus402; + case Microsoft.AspNetCore.Http.StatusCodes.Status403Forbidden: + return _bytesStatus403; + case Microsoft.AspNetCore.Http.StatusCodes.Status404NotFound: + return _bytesStatus404; + case Microsoft.AspNetCore.Http.StatusCodes.Status405MethodNotAllowed: + return _bytesStatus405; + case Microsoft.AspNetCore.Http.StatusCodes.Status406NotAcceptable: + return _bytesStatus406; + case Microsoft.AspNetCore.Http.StatusCodes.Status407ProxyAuthenticationRequired: + return _bytesStatus407; + case Microsoft.AspNetCore.Http.StatusCodes.Status408RequestTimeout: + return _bytesStatus408; + case Microsoft.AspNetCore.Http.StatusCodes.Status409Conflict: + return _bytesStatus409; + case Microsoft.AspNetCore.Http.StatusCodes.Status410Gone: + return _bytesStatus410; + case Microsoft.AspNetCore.Http.StatusCodes.Status411LengthRequired: + return _bytesStatus411; + case Microsoft.AspNetCore.Http.StatusCodes.Status412PreconditionFailed: + return _bytesStatus412; + case Microsoft.AspNetCore.Http.StatusCodes.Status413PayloadTooLarge: + return _bytesStatus413; + case Microsoft.AspNetCore.Http.StatusCodes.Status414UriTooLong: + return _bytesStatus414; + case Microsoft.AspNetCore.Http.StatusCodes.Status415UnsupportedMediaType: + return _bytesStatus415; + case Microsoft.AspNetCore.Http.StatusCodes.Status416RangeNotSatisfiable: + return _bytesStatus416; + case Microsoft.AspNetCore.Http.StatusCodes.Status417ExpectationFailed: + return _bytesStatus417; + case Microsoft.AspNetCore.Http.StatusCodes.Status418ImATeapot: + return _bytesStatus418; + case Microsoft.AspNetCore.Http.StatusCodes.Status419AuthenticationTimeout: + return _bytesStatus419; + case Microsoft.AspNetCore.Http.StatusCodes.Status421MisdirectedRequest: + return _bytesStatus421; + case Microsoft.AspNetCore.Http.StatusCodes.Status422UnprocessableEntity: + return _bytesStatus422; + case Microsoft.AspNetCore.Http.StatusCodes.Status423Locked: + return _bytesStatus423; + case Microsoft.AspNetCore.Http.StatusCodes.Status424FailedDependency: + return _bytesStatus424; + case Microsoft.AspNetCore.Http.StatusCodes.Status426UpgradeRequired: + return _bytesStatus426; + case Microsoft.AspNetCore.Http.StatusCodes.Status428PreconditionRequired: + return _bytesStatus428; + case Microsoft.AspNetCore.Http.StatusCodes.Status429TooManyRequests: + return _bytesStatus429; + case Microsoft.AspNetCore.Http.StatusCodes.Status431RequestHeaderFieldsTooLarge: + return _bytesStatus431; + case Microsoft.AspNetCore.Http.StatusCodes.Status451UnavailableForLegalReasons: + return _bytesStatus451; + + case Microsoft.AspNetCore.Http.StatusCodes.Status500InternalServerError: + return _bytesStatus500; + case Microsoft.AspNetCore.Http.StatusCodes.Status501NotImplemented: + return _bytesStatus501; + case Microsoft.AspNetCore.Http.StatusCodes.Status502BadGateway: + return _bytesStatus502; + case Microsoft.AspNetCore.Http.StatusCodes.Status503ServiceUnavailable: + return _bytesStatus503; + case Microsoft.AspNetCore.Http.StatusCodes.Status504GatewayTimeout: + return _bytesStatus504; + case Microsoft.AspNetCore.Http.StatusCodes.Status505HttpVersionNotsupported: + return _bytesStatus505; + case Microsoft.AspNetCore.Http.StatusCodes.Status506VariantAlsoNegotiates: + return _bytesStatus506; + case Microsoft.AspNetCore.Http.StatusCodes.Status507InsufficientStorage: + return _bytesStatus507; + case Microsoft.AspNetCore.Http.StatusCodes.Status508LoopDetected: + return _bytesStatus508; + case Microsoft.AspNetCore.Http.StatusCodes.Status510NotExtended: + return _bytesStatus510; + case Microsoft.AspNetCore.Http.StatusCodes.Status511NetworkAuthenticationRequired: + return _bytesStatus511; + + default: + return Encoding.ASCII.GetBytes(statusCode.ToString(CultureInfo.InvariantCulture)); + + } + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Connection.cs b/src/Kestrel.Core/Internal/Http2/Http2Connection.cs new file mode 100644 index 0000000000..4e3476d104 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Connection.cs @@ -0,0 +1,498 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Protocols; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2Connection : ITimeoutControl, IHttp2StreamLifetimeHandler + { + public static byte[] ClientPreface { get; } = Encoding.ASCII.GetBytes("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"); + + private readonly Http2ConnectionContext _context; + private readonly Http2FrameWriter _frameWriter; + private readonly HPackDecoder _hpackDecoder; + + private readonly Http2PeerSettings _serverSettings = new Http2PeerSettings(); + private readonly Http2PeerSettings _clientSettings = new Http2PeerSettings(); + + private readonly Http2Frame _incomingFrame = new Http2Frame(); + + private Http2Stream _currentHeadersStream; + private int _lastStreamId; + + private bool _stopping; + + private readonly ConcurrentDictionary _streams = new ConcurrentDictionary(); + + public Http2Connection(Http2ConnectionContext context) + { + _context = context; + _frameWriter = new Http2FrameWriter(context.Output); + _hpackDecoder = new HPackDecoder(); + } + + public string ConnectionId => _context.ConnectionId; + + public IPipeReader Input => _context.Input; + + public IKestrelTrace Log => _context.ServiceContext.Log; + + bool ITimeoutControl.TimedOut => throw new NotImplementedException(); + + public void Abort(Exception ex) + { + _stopping = true; + _frameWriter.Abort(ex); + } + + public void Stop() + { + _stopping = true; + Input.CancelPendingRead(); + } + + public async Task ProcessAsync(IHttpApplication application) + { + Exception error = null; + var errorCode = Http2ErrorCode.NO_ERROR; + + try + { + while (!_stopping) + { + var result = await Input.ReadAsync(); + var readableBuffer = result.Buffer; + var consumed = readableBuffer.Start; + var examined = readableBuffer.End; + + try + { + if (!readableBuffer.IsEmpty) + { + if (ParsePreface(readableBuffer, out consumed, out examined)) + { + break; + } + } + else if (result.IsCompleted) + { + return; + } + } + finally + { + Input.Advance(consumed, examined); + } + } + + if (!_stopping) + { + await _frameWriter.WriteSettingsAsync(_serverSettings); + } + + while (!_stopping) + { + var result = await Input.ReadAsync(); + var readableBuffer = result.Buffer; + var consumed = readableBuffer.Start; + var examined = readableBuffer.End; + + try + { + if (!readableBuffer.IsEmpty) + { + if (Http2FrameReader.ReadFrame(readableBuffer, _incomingFrame, out consumed, out examined)) + { + Log.LogTrace($"Connection id {ConnectionId} received {_incomingFrame.Type} frame with flags 0x{_incomingFrame.Flags:x} and length {_incomingFrame.Length} for stream ID {_incomingFrame.StreamId}"); + await ProcessFrameAsync(application); + } + } + else if (result.IsCompleted) + { + return; + } + } + finally + { + Input.Advance(consumed, examined); + } + } + } + catch (ConnectionAbortedException ex) + { + // TODO: log + error = ex; + } + catch (Http2ConnectionErrorException ex) + { + // TODO: log + error = ex; + errorCode = ex.ErrorCode; + } + catch (Exception ex) + { + // TODO: log + error = ex; + errorCode = Http2ErrorCode.INTERNAL_ERROR; + } + finally + { + try + { + foreach (var stream in _streams.Values) + { + stream.Abort(error); + } + + await _frameWriter.WriteGoAwayAsync(_lastStreamId, errorCode); + } + finally + { + Input.Complete(); + _frameWriter.Abort(ex: null); + } + } + } + + private bool ParsePreface(ReadableBuffer readableBuffer, out ReadCursor consumed, out ReadCursor examined) + { + consumed = readableBuffer.Start; + examined = readableBuffer.End; + + if (readableBuffer.Length < ClientPreface.Length) + { + return false; + } + + var span = readableBuffer.IsSingleSpan + ? readableBuffer.First.Span + : readableBuffer.ToSpan(); + + for (var i = 0; i < ClientPreface.Length; i++) + { + if (ClientPreface[i] != span[i]) + { + throw new Exception("Invalid HTTP/2 connection preface."); + } + } + + consumed = examined = readableBuffer.Move(readableBuffer.Start, ClientPreface.Length); + return true; + } + + private Task ProcessFrameAsync(IHttpApplication application) + { + switch (_incomingFrame.Type) + { + case Http2FrameType.DATA: + return ProcessDataFrameAsync(); + case Http2FrameType.HEADERS: + return ProcessHeadersFrameAsync(application); + case Http2FrameType.PRIORITY: + return ProcessPriorityFrameAsync(); + case Http2FrameType.RST_STREAM: + return ProcessRstStreamFrameAsync(); + case Http2FrameType.SETTINGS: + return ProcessSettingsFrameAsync(); + case Http2FrameType.PING: + return ProcessPingFrameAsync(); + case Http2FrameType.GOAWAY: + return ProcessGoAwayFrameAsync(); + case Http2FrameType.WINDOW_UPDATE: + return ProcessWindowUpdateFrameAsync(); + case Http2FrameType.CONTINUATION: + return ProcessContinuationFrameAsync(application); + } + + return Task.CompletedTask; + } + + private Task ProcessDataFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId == 0) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.DataHasPadding && _incomingFrame.DataPadLength >= _incomingFrame.Length) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_streams.TryGetValue(_incomingFrame.StreamId, out var stream) && !stream.MessageBody.IsCompleted) + { + return stream.MessageBody.OnDataAsync(_incomingFrame.DataPayload, + endStream: (_incomingFrame.DataFlags & Http2DataFrameFlags.END_STREAM) == Http2DataFrameFlags.END_STREAM); + } + + return _frameWriter.WriteRstStreamAsync(_incomingFrame.StreamId, Http2ErrorCode.STREAM_CLOSED); + } + + private Task ProcessHeadersFrameAsync(IHttpApplication application) + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId == 0) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.HeadersHasPadding && _incomingFrame.HeadersPadLength >= _incomingFrame.Length) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + _currentHeadersStream = new Http2Stream(application, new Http2StreamContext + { + ConnectionId = ConnectionId, + StreamId = _incomingFrame.StreamId, + ServiceContext = _context.ServiceContext, + PipeFactory = _context.PipeFactory, + LocalEndPoint = _context.LocalEndPoint, + RemoteEndPoint = _context.RemoteEndPoint, + StreamLifetimeHandler = this, + FrameWriter = _frameWriter + }); + _currentHeadersStream.ExpectBody = (_incomingFrame.HeadersFlags & Http2HeadersFrameFlags.END_STREAM) == 0; + _currentHeadersStream.Reset(); + + _streams[_incomingFrame.StreamId] = _currentHeadersStream; + + _hpackDecoder.Decode(_incomingFrame.HeadersPayload, _currentHeadersStream.RequestHeaders); + + if ((_incomingFrame.HeadersFlags & Http2HeadersFrameFlags.END_HEADERS) == Http2HeadersFrameFlags.END_HEADERS) + { + _lastStreamId = _incomingFrame.StreamId; + _ = _currentHeadersStream.ProcessRequestAsync(); + _currentHeadersStream = null; + } + + return Task.CompletedTask; + } + + private Task ProcessPriorityFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId == 0) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.Length != 5) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.FRAME_SIZE_ERROR); + } + + return Task.CompletedTask; + } + + private Task ProcessRstStreamFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId == 0) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.Length != 4) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.FRAME_SIZE_ERROR); + } + + if (_streams.TryGetValue(_incomingFrame.StreamId, out var stream)) + { + stream.Abort(error: null); + } + else if (_incomingFrame.StreamId > _lastStreamId) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + return Task.CompletedTask; + } + + private Task ProcessSettingsFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.StreamId != 0) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if ((_incomingFrame.SettingsFlags & Http2SettingsFrameFlags.ACK) == Http2SettingsFrameFlags.ACK && _incomingFrame.Length != 0) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.FRAME_SIZE_ERROR); + } + + if (_incomingFrame.Length % 6 != 0) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.FRAME_SIZE_ERROR); + } + + try + { + _clientSettings.ParseFrame(_incomingFrame); + return _frameWriter.WriteSettingsAckAsync(); + } + catch (Http2SettingsParameterOutOfRangeException ex) + { + throw new Http2ConnectionErrorException(ex.Parameter == Http2SettingsParameter.SETTINGS_INITIAL_WINDOW_SIZE + ? Http2ErrorCode.FLOW_CONTROL_ERROR + : Http2ErrorCode.PROTOCOL_ERROR); + } + } + + private Task ProcessPingFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.Length != 8) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.FRAME_SIZE_ERROR); + } + + return _frameWriter.WritePingAsync(Http2PingFrameFlags.ACK, _incomingFrame.Payload); + } + + private Task ProcessGoAwayFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + Stop(); + return Task.CompletedTask; + } + + private Task ProcessWindowUpdateFrameAsync() + { + if (_currentHeadersStream != null) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + if (_incomingFrame.Length != 4) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.FRAME_SIZE_ERROR); + } + + if (_incomingFrame.StreamId == 0) + { + if (_incomingFrame.WindowUpdateSizeIncrement == 0) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + } + else + { + if (_incomingFrame.WindowUpdateSizeIncrement == 0) + { + return _frameWriter.WriteRstStreamAsync(_incomingFrame.StreamId, Http2ErrorCode.PROTOCOL_ERROR); + } + } + + return Task.CompletedTask; + } + + private Task ProcessContinuationFrameAsync(IHttpApplication application) + { + if (_currentHeadersStream == null || _incomingFrame.StreamId != _currentHeadersStream.StreamId) + { + throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR); + } + + _hpackDecoder.Decode(_incomingFrame.HeadersPayload, _currentHeadersStream.RequestHeaders); + + if ((_incomingFrame.ContinuationFlags & Http2ContinuationFrameFlags.END_HEADERS) == Http2ContinuationFrameFlags.END_HEADERS) + { + _lastStreamId = _currentHeadersStream.StreamId; + _ = _currentHeadersStream.ProcessRequestAsync(); + _currentHeadersStream = null; + } + + return Task.CompletedTask; + } + + void IHttp2StreamLifetimeHandler.OnStreamCompleted(int streamId) + { + _streams.TryRemove(streamId, out _); + } + + void ITimeoutControl.SetTimeout(long ticks, TimeoutAction timeoutAction) + { + } + + void ITimeoutControl.ResetTimeout(long ticks, TimeoutAction timeoutAction) + { + } + + void ITimeoutControl.CancelTimeout() + { + } + + void ITimeoutControl.StartTimingReads() + { + } + + void ITimeoutControl.PauseTimingReads() + { + } + + void ITimeoutControl.ResumeTimingReads() + { + } + + void ITimeoutControl.StopTimingReads() + { + } + + void ITimeoutControl.BytesRead(long count) + { + } + + void ITimeoutControl.StartTimingWrite(long size) + { + } + + void ITimeoutControl.StopTimingWrite() + { + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2ConnectionContext.cs b/src/Kestrel.Core/Internal/Http2/Http2ConnectionContext.cs new file mode 100644 index 0000000000..ff22a1afb0 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2ConnectionContext.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO.Pipelines; +using System.Net; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2ConnectionContext + { + public string ConnectionId { get; set; } + public ServiceContext ServiceContext { get; set; } + public PipeFactory PipeFactory { get; set; } + public IPEndPoint LocalEndPoint { get; set; } + public IPEndPoint RemoteEndPoint { get; set; } + + public IPipeReader Input { get; set; } + public IPipe Output { get; set; } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2ConnectionErrorException.cs b/src/Kestrel.Core/Internal/Http2/Http2ConnectionErrorException.cs new file mode 100644 index 0000000000..b1b9136774 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2ConnectionErrorException.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2ConnectionErrorException : Exception + { + public Http2ConnectionErrorException(Http2ErrorCode errorCode) + : base($"HTTP/2 connection error: {errorCode}") + { + ErrorCode = errorCode; + } + + public Http2ErrorCode ErrorCode { get; } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2ContinuationFrameFlags.cs b/src/Kestrel.Core/Internal/Http2/Http2ContinuationFrameFlags.cs new file mode 100644 index 0000000000..65e65bc0bc --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2ContinuationFrameFlags.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + [Flags] + public enum Http2ContinuationFrameFlags : byte + { + NONE = 0x0, + END_HEADERS = 0x4, + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2DataFrameFlags.cs b/src/Kestrel.Core/Internal/Http2/Http2DataFrameFlags.cs new file mode 100644 index 0000000000..735a4aea30 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2DataFrameFlags.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + [Flags] + public enum Http2DataFrameFlags : byte + { + NONE = 0x0, + END_STREAM = 0x1, + PADDED = 0x8 + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2ErrorCode.cs b/src/Kestrel.Core/Internal/Http2/Http2ErrorCode.cs new file mode 100644 index 0000000000..2f14f191ba --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2ErrorCode.cs @@ -0,0 +1,25 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public enum Http2ErrorCode : uint + { + NO_ERROR = 0x0, + PROTOCOL_ERROR = 0x1, + INTERNAL_ERROR = 0x2, + FLOW_CONTROL_ERROR = 0x3, + SETTINGS_TIMEOUT = 0x4, + STREAM_CLOSED = 0x5, + FRAME_SIZE_ERROR = 0x6, + REFUSED_STREAM = 0x7, + CANCEL = 0x8, + COMPRESSION_ERROR = 0x9, + CONNECT_ERROR = 0xa, + ENHANCE_YOUR_CALM = 0xb, + INADEQUATE_SECURITY = 0xc, + HTTP_1_1_REQUIRED = 0xd, + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Frame.Continuation.cs b/src/Kestrel.Core/Internal/Http2/Http2Frame.Continuation.cs new file mode 100644 index 0000000000..e184ff64db --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Frame.Continuation.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public Http2ContinuationFrameFlags ContinuationFlags + { + get => (Http2ContinuationFrameFlags)Flags; + set => Flags = (byte)value; + } + + public void PrepareContinuation(Http2ContinuationFrameFlags flags, int streamId) + { + Length = MinAllowedMaxFrameSize - HeaderLength; + Type = Http2FrameType.CONTINUATION; + ContinuationFlags = flags; + StreamId = streamId; + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Frame.Data.cs b/src/Kestrel.Core/Internal/Http2/Http2Frame.Data.cs new file mode 100644 index 0000000000..78adb78c57 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Frame.Data.cs @@ -0,0 +1,44 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public Http2DataFrameFlags DataFlags + { + get => (Http2DataFrameFlags)Flags; + set => Flags = (byte)value; + } + + public bool DataHasPadding => (DataFlags & Http2DataFrameFlags.PADDED) == Http2DataFrameFlags.PADDED; + + public byte DataPadLength + { + get => DataHasPadding ? _data[PayloadOffset] : (byte)0; + set => _data[PayloadOffset] = value; + } + + public ArraySegment DataPayload => DataHasPadding + ? new ArraySegment(_data, PayloadOffset + 1, Length - DataPadLength - 1) + : new ArraySegment(_data, PayloadOffset, Length); + + public void PrepareData(int streamId, byte? padLength = null) + { + var padded = padLength != null; + + Length = MinAllowedMaxFrameSize - HeaderLength; + Type = Http2FrameType.DATA; + DataFlags = padded ? Http2DataFrameFlags.PADDED : Http2DataFrameFlags.NONE; + StreamId = streamId; + + if (padded) + { + DataPadLength = padLength.Value; + Payload.Slice(Length - padLength.Value).Fill(0); + } + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Frame.GoAway.cs b/src/Kestrel.Core/Internal/Http2/Http2Frame.GoAway.cs new file mode 100644 index 0000000000..3e430de8b0 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Frame.GoAway.cs @@ -0,0 +1,42 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public int GoAwayLastStreamId + { + get => (Payload[0] << 24) | (Payload[1] << 16) | (Payload[2] << 16) | Payload[3]; + set + { + Payload[0] = (byte)((value >> 24) & 0xff); + Payload[1] = (byte)((value >> 16) & 0xff); + Payload[2] = (byte)((value >> 8) & 0xff); + Payload[3] = (byte)(value & 0xff); + } + } + + public Http2ErrorCode GoAwayErrorCode + { + get => (Http2ErrorCode)((Payload[4] << 24) | (Payload[5] << 16) | (Payload[6] << 16) | Payload[7]); + set + { + Payload[4] = (byte)(((uint)value >> 24) & 0xff); + Payload[5] = (byte)(((uint)value >> 16) & 0xff); + Payload[6] = (byte)(((uint)value >> 8) & 0xff); + Payload[7] = (byte)((uint)value & 0xff); + } + } + + public void PrepareGoAway(int lastStreamId, Http2ErrorCode errorCode) + { + Length = 8; + Type = Http2FrameType.GOAWAY; + Flags = 0; + StreamId = 0; + GoAwayLastStreamId = lastStreamId; + GoAwayErrorCode = errorCode; + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Frame.Headers.cs b/src/Kestrel.Core/Internal/Http2/Http2Frame.Headers.cs new file mode 100644 index 0000000000..52c20bde94 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Frame.Headers.cs @@ -0,0 +1,72 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public Http2HeadersFrameFlags HeadersFlags + { + get => (Http2HeadersFrameFlags)Flags; + set => Flags = (byte)value; + } + + public bool HeadersHasPadding => (HeadersFlags & Http2HeadersFrameFlags.PADDED) == Http2HeadersFrameFlags.PADDED; + + public byte HeadersPadLength + { + get => HeadersHasPadding ? _data[HeaderLength] : (byte)0; + set => _data[HeaderLength] = value; + } + + public bool HeadersHasPriority => (HeadersFlags & Http2HeadersFrameFlags.PRIORITY) == Http2HeadersFrameFlags.PRIORITY; + + public byte HeadersPriority + { + get => _data[HeadersPriorityOffset]; + set => _data[HeadersPriorityOffset] = value; + } + + private int HeadersPriorityOffset => PayloadOffset + (HeadersHasPadding ? 1 : 0) + 4; + + public int HeadersStreamDependency + { + get + { + var offset = HeadersStreamDependencyOffset; + + return (int)((uint)((_data[offset] << 24) + | (_data[offset + 1] << 16) + | (_data[offset + 2] << 8) + | _data[offset + 3]) & 0x7fffffff); + } + set + { + var offset = HeadersStreamDependencyOffset; + + _data[offset] = (byte)((value & 0xff000000) >> 24); + _data[offset + 1] = (byte)((value & 0x00ff0000) >> 16); + _data[offset + 2] = (byte)((value & 0x0000ff00) >> 8); + _data[offset + 3] = (byte)(value & 0x000000ff); + } + } + + private int HeadersStreamDependencyOffset => PayloadOffset + (HeadersHasPadding ? 1 : 0); + + public Span HeadersPayload => new Span(_data, HeadersPayloadOffset, HeadersPayloadLength); + + private int HeadersPayloadOffset => PayloadOffset + (HeadersHasPadding ? 1 : 0) + (HeadersHasPriority ? 5 : 0); + + private int HeadersPayloadLength => Length - ((HeadersHasPadding ? 1 : 0) + (HeadersHasPriority ? 5 : 0)) - HeadersPadLength; + + public void PrepareHeaders(Http2HeadersFrameFlags flags, int streamId) + { + Length = MinAllowedMaxFrameSize - HeaderLength; + Type = Http2FrameType.HEADERS; + HeadersFlags = flags; + StreamId = streamId; + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Frame.Ping.cs b/src/Kestrel.Core/Internal/Http2/Http2Frame.Ping.cs new file mode 100644 index 0000000000..932d717ac1 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Frame.Ping.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public Http2PingFrameFlags PingFlags + { + get => (Http2PingFrameFlags)Flags; + set => Flags = (byte)value; + } + + public void PreparePing(Http2PingFrameFlags flags) + { + Length = 8; + Type = Http2FrameType.PING; + PingFlags = flags; + StreamId = 0; + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Frame.Priority.cs b/src/Kestrel.Core/Internal/Http2/Http2Frame.Priority.cs new file mode 100644 index 0000000000..02f9bf02f9 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Frame.Priority.cs @@ -0,0 +1,57 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public int PriorityStreamDependency + { + get => ((_data[PayloadOffset] << 24) + | (_data[PayloadOffset + 1] << 16) + | (_data[PayloadOffset + 2] << 8) + | _data[PayloadOffset + 3]) & 0x7fffffff; + set + { + _data[PayloadOffset] = (byte)((value & 0x7f000000) >> 24); + _data[PayloadOffset + 1] = (byte)((value & 0x00ff0000) >> 16); + _data[PayloadOffset + 2] = (byte)((value & 0x0000ff00) >> 8); + _data[PayloadOffset + 3] = (byte)(value & 0x000000ff); + } + } + + + public bool PriorityIsExclusive + { + get => (_data[PayloadOffset] & 0x80000000) != 0; + set + { + if (value) + { + _data[PayloadOffset] |= 0x80; + } + else + { + _data[PayloadOffset] &= 0x7f; + } + } + } + + public byte PriorityWeight + { + get => _data[PayloadOffset + 4]; + set => _data[PayloadOffset] = value; + } + + + public void PreparePriority(int streamId, int streamDependency, bool exclusive, byte weight) + { + Length = 5; + Type = Http2FrameType.PRIORITY; + StreamId = streamId; + PriorityStreamDependency = streamDependency; + PriorityIsExclusive = exclusive; + PriorityWeight = weight; + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Frame.RstStream.cs b/src/Kestrel.Core/Internal/Http2/Http2Frame.RstStream.cs new file mode 100644 index 0000000000..8a0bcdfd6c --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Frame.RstStream.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public Http2ErrorCode RstStreamErrorCode + { + get => (Http2ErrorCode)((Payload[0] << 24) | (Payload[1] << 16) | (Payload[2] << 16) | Payload[3]); + set + { + Payload[0] = (byte)(((uint)value >> 24) & 0xff); + Payload[1] = (byte)(((uint)value >> 16) & 0xff); + Payload[2] = (byte)(((uint)value >> 8) & 0xff); + Payload[3] = (byte)((uint)value & 0xff); + } + } + + public void PrepareRstStream(int streamId, Http2ErrorCode errorCode) + { + Length = 4; + Type = Http2FrameType.RST_STREAM; + Flags = 0; + StreamId = streamId; + RstStreamErrorCode = errorCode; + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Frame.Settings.cs b/src/Kestrel.Core/Internal/Http2/Http2Frame.Settings.cs new file mode 100644 index 0000000000..04cc78b209 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Frame.Settings.cs @@ -0,0 +1,42 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public Http2SettingsFrameFlags SettingsFlags + { + get => (Http2SettingsFrameFlags)Flags; + set => Flags = (byte)value; + } + + public void PrepareSettings(Http2SettingsFrameFlags flags, Http2PeerSettings settings = null) + { + var settingCount = settings?.Count() ?? 0; + + Length = 6 * settingCount; + Type = Http2FrameType.SETTINGS; + SettingsFlags = flags; + StreamId = 0; + + if (settings != null) + { + Span payload = Payload; + foreach (var setting in settings) + { + payload[0] = (byte)((ushort)setting.Parameter >> 8); + payload[1] = (byte)(ushort)setting.Parameter; + payload[2] = (byte)(setting.Value >> 24); + payload[3] = (byte)(setting.Value >> 16); + payload[4] = (byte)(setting.Value >> 8); + payload[5] = (byte)setting.Value; + payload = payload.Slice(6); + } + } + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Frame.WindowUpdate.cs b/src/Kestrel.Core/Internal/Http2/Http2Frame.WindowUpdate.cs new file mode 100644 index 0000000000..6958b376f5 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Frame.WindowUpdate.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public int WindowUpdateSizeIncrement + { + get => ((Payload[0] << 24) | (Payload[1] << 16) | (Payload[2] << 16) | Payload[3]) & 0x7fffffff; + set + { + Payload[0] = (byte)(((uint)value >> 24) & 0x7f); + Payload[1] = (byte)(((uint)value >> 16) & 0xff); + Payload[2] = (byte)(((uint)value >> 8) & 0xff); + Payload[3] = (byte)((uint)value & 0xff); + } + } + + public void PrepareWindowUpdate(int streamId, int sizeIncrement) + { + Length = 4; + Type = Http2FrameType.WINDOW_UPDATE; + Flags = 0; + StreamId = streamId; + WindowUpdateSizeIncrement = sizeIncrement; + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Frame.cs b/src/Kestrel.Core/Internal/Http2/Http2Frame.cs new file mode 100644 index 0000000000..3e9e15750d --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Frame.cs @@ -0,0 +1,70 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Frame + { + public const int MinAllowedMaxFrameSize = 16 * 1024; + public const int MaxAllowedMaxFrameSize = 16 * 1024 * 1024 - 1; + public const int HeaderLength = 9; + + private const int LengthOffset = 0; + private const int TypeOffset = 3; + private const int FlagsOffset = 4; + private const int StreamIdOffset = 5; + private const int PayloadOffset = 9; + + private readonly byte[] _data = new byte[MinAllowedMaxFrameSize]; + + public ArraySegment Raw => new ArraySegment(_data, 0, HeaderLength + Length); + + public int Length + { + get => (_data[LengthOffset] << 16) | (_data[LengthOffset + 1] << 8) | _data[LengthOffset + 2]; + set + { + _data[LengthOffset] = (byte)((value & 0x00ff0000) >> 16); + _data[LengthOffset + 1] = (byte)((value & 0x0000ff00) >> 8); + _data[LengthOffset + 2] = (byte)(value & 0x000000ff); + } + } + + public Http2FrameType Type + { + get => (Http2FrameType)_data[TypeOffset]; + set + { + _data[TypeOffset] = (byte)value; + } + } + + public byte Flags + { + get => _data[FlagsOffset]; + set + { + _data[FlagsOffset] = (byte)value; + } + } + + public int StreamId + { + get => (int)((uint)((_data[StreamIdOffset] << 24) + | (_data[StreamIdOffset + 1] << 16) + | (_data[StreamIdOffset + 2] << 8) + | _data[StreamIdOffset + 3]) & 0x7fffffff); + set + { + _data[StreamIdOffset] = (byte)((value & 0xff000000) >> 24); + _data[StreamIdOffset + 1] = (byte)((value & 0x00ff0000) >> 16); + _data[StreamIdOffset + 2] = (byte)((value & 0x0000ff00) >> 8); + _data[StreamIdOffset + 3] = (byte)(value & 0x000000ff); + } + } + + public Span Payload => new Span(_data, PayloadOffset, Length); + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2FrameReader.cs b/src/Kestrel.Core/Internal/Http2/Http2FrameReader.cs new file mode 100644 index 0000000000..2a45009277 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2FrameReader.cs @@ -0,0 +1,34 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO.Pipelines; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public static class Http2FrameReader + { + public static bool ReadFrame(ReadableBuffer readableBuffer, Http2Frame frame, out ReadCursor consumed, out ReadCursor examined) + { + consumed = readableBuffer.Start; + examined = readableBuffer.End; + + if (readableBuffer.Length < Http2Frame.HeaderLength) + { + return false; + } + + var headerSlice = readableBuffer.Slice(0, Http2Frame.HeaderLength); + headerSlice.CopyTo(frame.Raw); + + if (readableBuffer.Length < Http2Frame.HeaderLength + frame.Length) + { + return false; + } + + readableBuffer.Slice(Http2Frame.HeaderLength, frame.Length).CopyTo(frame.Payload); + consumed = examined = readableBuffer.Move(readableBuffer.Start, Http2Frame.HeaderLength + frame.Length); + + return true; + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2FrameType.cs b/src/Kestrel.Core/Internal/Http2/Http2FrameType.cs new file mode 100644 index 0000000000..a09272a6be --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2FrameType.cs @@ -0,0 +1,19 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public enum Http2FrameType : byte + { + DATA = 0x0, + HEADERS = 0x1, + PRIORITY = 0x2, + RST_STREAM = 0x3, + SETTINGS = 0x4, + PUSH_PROMISE = 0x5, + PING = 0x6, + GOAWAY = 0x7, + WINDOW_UPDATE = 0x8, + CONTINUATION = 0x9 + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2FrameWriter.cs b/src/Kestrel.Core/Internal/Http2/Http2FrameWriter.cs new file mode 100644 index 0000000000..698f6a19c5 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2FrameWriter.cs @@ -0,0 +1,192 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2FrameWriter : IHttp2FrameWriter + { + private static readonly ArraySegment _emptyData = new ArraySegment(new byte[0]); + + private readonly Http2Frame _outgoingFrame = new Http2Frame(); + private readonly object _writeLock = new object(); + private readonly HPackEncoder _hpackEncoder = new HPackEncoder(); + private readonly IPipe _output; + + private bool _completed; + + public Http2FrameWriter(IPipe output) + { + _output = output; + } + + public void Abort(Exception ex) + { + lock (_writeLock) + { + _completed = true; + _output.Reader.CancelPendingRead(); + _output.Writer.Complete(ex); + } + } + + public Task FlushAsync(CancellationToken cancellationToken) + { + return WriteAsync(_emptyData); + } + + public Task Write100ContinueAsync(int streamId) + { + return Task.CompletedTask; + } + + public Task WriteHeadersAsync(int streamId, int statusCode, IHeaderDictionary headers) + { + var tasks = new List(); + + lock (_writeLock) + { + _outgoingFrame.PrepareHeaders(Http2HeadersFrameFlags.NONE, streamId); + + var done = _hpackEncoder.BeginEncode(statusCode, EnumerateHeaders(headers), _outgoingFrame.Payload, out var payloadLength); + _outgoingFrame.Length = payloadLength; + + if (done) + { + _outgoingFrame.HeadersFlags = Http2HeadersFrameFlags.END_HEADERS; + } + + tasks.Add(WriteAsync(_outgoingFrame.Raw)); + + while (!done) + { + _outgoingFrame.PrepareContinuation(Http2ContinuationFrameFlags.NONE, streamId); + + done = _hpackEncoder.Encode(_outgoingFrame.Payload, out var length); + _outgoingFrame.Length = length; + + if (done) + { + _outgoingFrame.ContinuationFlags = Http2ContinuationFrameFlags.END_HEADERS; + } + + tasks.Add(WriteAsync(_outgoingFrame.Raw)); + } + + return Task.WhenAll(tasks); + } + } + + public Task WriteDataAsync(int streamId, Span data, CancellationToken cancellationToken) + => WriteDataAsync(streamId, data, endStream: false, cancellationToken: cancellationToken); + + public Task WriteDataAsync(int streamId, Span data, bool endStream, CancellationToken cancellationToken) + { + var tasks = new List(); + + lock (_writeLock) + { + _outgoingFrame.PrepareData(streamId); + + while (data.Length > _outgoingFrame.Length) + { + data.Slice(0, _outgoingFrame.Length).CopyTo(_outgoingFrame.Payload); + data = data.Slice(_outgoingFrame.Length); + + tasks.Add(WriteAsync(_outgoingFrame.Raw, cancellationToken)); + } + + _outgoingFrame.Length = data.Length; + + if (endStream) + { + _outgoingFrame.DataFlags = Http2DataFrameFlags.END_STREAM; + } + + data.CopyTo(_outgoingFrame.Payload); + + tasks.Add(WriteAsync(_outgoingFrame.Raw, cancellationToken)); + + return Task.WhenAll(tasks); + } + } + + public Task WriteRstStreamAsync(int streamId, Http2ErrorCode errorCode) + { + lock (_writeLock) + { + _outgoingFrame.PrepareRstStream(streamId, errorCode); + return WriteAsync(_outgoingFrame.Raw); + } + } + + public Task WriteSettingsAsync(Http2PeerSettings settings) + { + lock (_writeLock) + { + // TODO: actually send settings + _outgoingFrame.PrepareSettings(Http2SettingsFrameFlags.NONE); + return WriteAsync(_outgoingFrame.Raw); + } + } + + public Task WriteSettingsAckAsync() + { + lock (_writeLock) + { + _outgoingFrame.PrepareSettings(Http2SettingsFrameFlags.ACK); + return WriteAsync(_outgoingFrame.Raw); + } + } + + public Task WritePingAsync(Http2PingFrameFlags flags, Span payload) + { + lock (_writeLock) + { + _outgoingFrame.PreparePing(Http2PingFrameFlags.ACK); + payload.CopyTo(_outgoingFrame.Payload); + return WriteAsync(_outgoingFrame.Raw); + } + } + + public Task WriteGoAwayAsync(int lastStreamId, Http2ErrorCode errorCode) + { + lock (_writeLock) + { + _outgoingFrame.PrepareGoAway(lastStreamId, errorCode); + return WriteAsync(_outgoingFrame.Raw); + } + } + + private async Task WriteAsync(ArraySegment data, CancellationToken cancellationToken = default(CancellationToken)) + { + if (_completed) + { + return; + } + + var writeableBuffer = _output.Writer.Alloc(1); + writeableBuffer.Write(data); + await writeableBuffer.FlushAsync(cancellationToken); + } + + private static IEnumerable> EnumerateHeaders(IHeaderDictionary headers) + { + foreach (var header in headers) + { + foreach (var value in header.Value) + { + yield return new KeyValuePair(header.Key, value); + } + } + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2HeadersFrameFlags.cs b/src/Kestrel.Core/Internal/Http2/Http2HeadersFrameFlags.cs new file mode 100644 index 0000000000..564371e1be --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2HeadersFrameFlags.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + [Flags] + public enum Http2HeadersFrameFlags : byte + { + NONE = 0x0, + END_STREAM = 0x1, + END_HEADERS = 0x4, + PADDED = 0x8, + PRIORITY = 0x20 + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2MessageBody.cs b/src/Kestrel.Core/Internal/Http2/Http2MessageBody.cs new file mode 100644 index 0000000000..81ab3a5595 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2MessageBody.cs @@ -0,0 +1,262 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public abstract class Http2MessageBody : IMessageBody + { + private static readonly Http2MessageBody _emptyMessageBody = new ForEmpty(); + + private readonly Http2Stream _context; + + private bool _send100Continue = true; + + protected Http2MessageBody(Http2Stream context) + { + _context = context; + } + + public bool IsCompleted { get; protected set; } + + public virtual async Task OnDataAsync(ArraySegment data, bool endStream) + { + try + { + if (data.Count > 0) + { + var writableBuffer = _context.RequestBodyPipe.Writer.Alloc(1); + bool done; + + try + { + done = Read(data, writableBuffer); + } + finally + { + writableBuffer.Commit(); + } + + await writableBuffer.FlushAsync(); + } + + if (endStream) + { + IsCompleted = true; + _context.RequestBodyPipe.Writer.Complete(); + } + } + catch (Exception ex) + { + _context.RequestBodyPipe.Writer.Complete(ex); + } + } + + public virtual async Task ReadAsync(ArraySegment buffer, CancellationToken cancellationToken = default(CancellationToken)) + { + TryInit(); + + while (true) + { + var result = await _context.RequestBodyPipe.Reader.ReadAsync(); + var readableBuffer = result.Buffer; + var consumed = readableBuffer.End; + + try + { + if (!readableBuffer.IsEmpty) + { + // buffer.Count is int + var actual = (int)Math.Min(readableBuffer.Length, buffer.Count); + var slice = readableBuffer.Slice(0, actual); + consumed = readableBuffer.Move(readableBuffer.Start, actual); + slice.CopyTo(buffer); + return actual; + } + else if (result.IsCompleted) + { + return 0; + } + } + finally + { + _context.RequestBodyPipe.Reader.Advance(consumed); + } + } + } + + public virtual async Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default(CancellationToken)) + { + TryInit(); + + while (true) + { + var result = await _context.RequestBodyPipe.Reader.ReadAsync(); + var readableBuffer = result.Buffer; + var consumed = readableBuffer.End; + + try + { + if (!readableBuffer.IsEmpty) + { + foreach (var memory in readableBuffer) + { + var array = memory.GetArray(); + await destination.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken); + } + } + else if (result.IsCompleted) + { + return; + } + } + finally + { + _context.RequestBodyPipe.Reader.Advance(consumed); + } + } + } + + public virtual Task StopAsync() + { + _context.RequestBodyPipe.Reader.Complete(); + _context.RequestBodyPipe.Writer.Complete(); + return Task.CompletedTask; + } + + protected void Copy(Span data, WritableBuffer writableBuffer) + { + writableBuffer.Write(data); + } + + private void TryProduceContinue() + { + if (_send100Continue) + { + _context.HttpStreamControl.ProduceContinue(); + _send100Continue = false; + } + } + + private void TryInit() + { + if (!_context.HasStartedConsumingRequestBody) + { + OnReadStart(); + _context.HasStartedConsumingRequestBody = true; + } + } + + protected virtual bool Read(Span readableBuffer, WritableBuffer writableBuffer) + { + throw new NotImplementedException(); + } + + protected virtual void OnReadStart() + { + } + + public static Http2MessageBody For( + FrameRequestHeaders headers, + Http2Stream context) + { + if (!context.ExpectBody) + { + return _emptyMessageBody; + } + + if (headers.ContentLength.HasValue) + { + var contentLength = headers.ContentLength.Value; + + return new ForContentLength(contentLength, context); + } + + return new ForRemainingData(context); + } + + private class ForEmpty : Http2MessageBody + { + public ForEmpty() + : base(context: null) + { + IsCompleted = true; + } + + public override Task OnDataAsync(ArraySegment data, bool endStream) + { + throw new NotImplementedException(); + } + + public override Task ReadAsync(ArraySegment buffer, CancellationToken cancellationToken) + { + return Task.FromResult(0); + } + + public override Task CopyToAsync(Stream destination, CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + } + + private class ForRemainingData : Http2MessageBody + { + public ForRemainingData(Http2Stream context) + : base(context) + { + } + + protected override bool Read(Span data, WritableBuffer writableBuffer) + { + Copy(data, writableBuffer); + return false; + } + } + + private class ForContentLength : Http2MessageBody + { + private readonly long _contentLength; + private long _inputLength; + + public ForContentLength(long contentLength, Http2Stream context) + : base(context) + { + _contentLength = contentLength; + _inputLength = _contentLength; + } + + protected override bool Read(Span data, WritableBuffer writableBuffer) + { + if (_inputLength == 0) + { + throw new InvalidOperationException("Attempted to read from completed Content-Length request body."); + } + + if (data.Length > _inputLength) + { + _context.ThrowRequestRejected(RequestRejectionReason.RequestBodyExceedsContentLength); + } + + _inputLength -= data.Length; + + Copy(data, writableBuffer); + + return _inputLength == 0; + } + + protected override void OnReadStart() + { + if (_contentLength > _context.MaxRequestBodySize) + { + _context.ThrowRequestRejected(RequestRejectionReason.RequestBodyTooLarge); + } + } + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2PeerSetting.cs b/src/Kestrel.Core/Internal/Http2/Http2PeerSetting.cs new file mode 100644 index 0000000000..f21b3ca929 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2PeerSetting.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public struct Http2PeerSetting + { + public Http2PeerSetting(Http2SettingsParameter parameter, uint value) + { + Parameter = parameter; + Value = value; + } + + public Http2SettingsParameter Parameter { get; } + + public uint Value { get; } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2PeerSettings.cs b/src/Kestrel.Core/Internal/Http2/Http2PeerSettings.cs new file mode 100644 index 0000000000..af94d593df --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2PeerSettings.cs @@ -0,0 +1,106 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2PeerSettings : IEnumerable + { + public const uint DefaultHeaderTableSize = 4096; + public const bool DefaultEnablePush = true; + public const uint DefaultMaxConcurrentStreams = uint.MaxValue; + public const uint DefaultInitialWindowSize = 65535; + public const uint DefaultMaxFrameSize = 16384; + public const uint DefaultMaxHeaderListSize = uint.MaxValue; + + public uint HeaderTableSize { get; set; } = DefaultHeaderTableSize; + + public bool EnablePush { get; set; } = DefaultEnablePush; + + public uint MaxConcurrentStreams { get; set; } = DefaultMaxConcurrentStreams; + + public uint InitialWindowSize { get; set; } = DefaultInitialWindowSize; + + public uint MaxFrameSize { get; set; } = DefaultMaxFrameSize; + + public uint MaxHeaderListSize { get; set; } = DefaultMaxHeaderListSize; + + public void ParseFrame(Http2Frame frame) + { + var settingsCount = frame.Length / 6; + + for (var i = 0; i < settingsCount; i++) + { + var offset = i * 6; + var id = (Http2SettingsParameter)((frame.Payload[offset] << 8) | frame.Payload[offset + 1]); + var value = (uint)((frame.Payload[offset + 2] << 24) + | (frame.Payload[offset + 3] << 16) + | (frame.Payload[offset + 4] << 8) + | frame.Payload[offset + 5]); + + switch (id) + { + case Http2SettingsParameter.SETTINGS_HEADER_TABLE_SIZE: + HeaderTableSize = value; + break; + case Http2SettingsParameter.SETTINGS_ENABLE_PUSH: + if (value != 0 && value != 1) + { + throw new Http2SettingsParameterOutOfRangeException(Http2SettingsParameter.SETTINGS_ENABLE_PUSH, + lowerBound: 0, + upperBound: 1); + } + + EnablePush = value == 1; + break; + case Http2SettingsParameter.SETTINGS_MAX_CONCURRENT_STREAMS: + MaxConcurrentStreams = value; + break; + case Http2SettingsParameter.SETTINGS_INITIAL_WINDOW_SIZE: + if (value > int.MaxValue) + { + throw new Http2SettingsParameterOutOfRangeException(Http2SettingsParameter.SETTINGS_INITIAL_WINDOW_SIZE, + lowerBound: 0, + upperBound: int.MaxValue); + } + + InitialWindowSize = value; + break; + case Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE: + if (value < Http2Frame.MinAllowedMaxFrameSize || value > Http2Frame.MaxAllowedMaxFrameSize) + { + throw new Http2SettingsParameterOutOfRangeException(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, + lowerBound: Http2Frame.MinAllowedMaxFrameSize, + upperBound: Http2Frame.MaxAllowedMaxFrameSize); + } + + MaxFrameSize = value; + break; + case Http2SettingsParameter.SETTINGS_MAX_HEADER_LIST_SIZE: + MaxHeaderListSize = value; + break; + default: + // http://httpwg.org/specs/rfc7540.html#rfc.section.6.5.2 + // + // An endpoint that receives a SETTINGS frame with any unknown or unsupported identifier MUST ignore that setting. + break; + } + } + } + + public IEnumerator GetEnumerator() + { + yield return new Http2PeerSetting(Http2SettingsParameter.SETTINGS_HEADER_TABLE_SIZE, HeaderTableSize); + yield return new Http2PeerSetting(Http2SettingsParameter.SETTINGS_ENABLE_PUSH, EnablePush ? 1u : 0); + yield return new Http2PeerSetting(Http2SettingsParameter.SETTINGS_MAX_CONCURRENT_STREAMS, MaxConcurrentStreams); + yield return new Http2PeerSetting(Http2SettingsParameter.SETTINGS_INITIAL_WINDOW_SIZE, InitialWindowSize); + yield return new Http2PeerSetting(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, MaxFrameSize); + yield return new Http2PeerSetting(Http2SettingsParameter.SETTINGS_MAX_HEADER_LIST_SIZE, MaxHeaderListSize); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2PingFrameFlags.cs b/src/Kestrel.Core/Internal/Http2/Http2PingFrameFlags.cs new file mode 100644 index 0000000000..da5163f7e7 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2PingFrameFlags.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + [Flags] + public enum Http2PingFrameFlags : byte + { + NONE = 0x0, + ACK = 0x1 + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2SettingsFrameFlags.cs b/src/Kestrel.Core/Internal/Http2/Http2SettingsFrameFlags.cs new file mode 100644 index 0000000000..5b0b8666cd --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2SettingsFrameFlags.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + [Flags] + public enum Http2SettingsFrameFlags : byte + { + NONE = 0x0, + ACK = 0x1, + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2SettingsParameter.cs b/src/Kestrel.Core/Internal/Http2/Http2SettingsParameter.cs new file mode 100644 index 0000000000..918422a4c2 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2SettingsParameter.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public enum Http2SettingsParameter : ushort + { + SETTINGS_HEADER_TABLE_SIZE = 0x1, + SETTINGS_ENABLE_PUSH = 0x2, + SETTINGS_MAX_CONCURRENT_STREAMS = 0x3, + SETTINGS_INITIAL_WINDOW_SIZE = 0x4, + SETTINGS_MAX_FRAME_SIZE = 0x5, + SETTINGS_MAX_HEADER_LIST_SIZE = 0x6, + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2SettingsParameterOutOfRangeException.cs b/src/Kestrel.Core/Internal/Http2/Http2SettingsParameterOutOfRangeException.cs new file mode 100644 index 0000000000..95db1c9d58 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2SettingsParameterOutOfRangeException.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2SettingsParameterOutOfRangeException : Exception + { + public Http2SettingsParameterOutOfRangeException(Http2SettingsParameter parameter, uint lowerBound, uint upperBound) + : base($"HTTP/2 SETTINGS parameter {parameter} must be set to a value between {lowerBound} and {upperBound}") + { + Parameter = parameter; + } + + public Http2SettingsParameter Parameter { get; } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Stream.FeatureCollection.cs b/src/Kestrel.Core/Internal/Http2/Http2Stream.FeatureCollection.cs new file mode 100644 index 0000000000..2adb8074d7 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Stream.FeatureCollection.cs @@ -0,0 +1,289 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Stream : IFeatureCollection, + IHttpRequestFeature, + IHttpResponseFeature, + IHttpUpgradeFeature, + IHttpConnectionFeature, + IHttpRequestLifetimeFeature, + IHttpRequestIdentifierFeature, + IHttpBodyControlFeature, + IHttpMaxRequestBodySizeFeature, + IHttpMinRequestBodyDataRateFeature, + IHttpMinResponseDataRateFeature, + IHttp2StreamIdFeature + { + // NOTE: When feature interfaces are added to or removed from this Frame class implementation, + // then the list of `implementedFeatures` in the generated code project MUST also be updated. + // See also: tools/Microsoft.AspNetCore.Server.Kestrel.GeneratedCode/FrameFeatureCollection.cs + + private int _featureRevision; + + private List> MaybeExtra; + + public void ResetFeatureCollection() + { + FastReset(); + MaybeExtra?.Clear(); + _featureRevision++; + } + + private object ExtraFeatureGet(Type key) + { + if (MaybeExtra == null) + { + return null; + } + for (var i = 0; i < MaybeExtra.Count; i++) + { + var kv = MaybeExtra[i]; + if (kv.Key == key) + { + return kv.Value; + } + } + return null; + } + + private void ExtraFeatureSet(Type key, object value) + { + if (MaybeExtra == null) + { + MaybeExtra = new List>(2); + } + + for (var i = 0; i < MaybeExtra.Count; i++) + { + if (MaybeExtra[i].Key == key) + { + MaybeExtra[i] = new KeyValuePair(key, value); + return; + } + } + MaybeExtra.Add(new KeyValuePair(key, value)); + } + + string IHttpRequestFeature.Protocol + { + get => HttpVersion; + set => throw new InvalidOperationException(); + } + + string IHttpRequestFeature.Scheme + { + get => Scheme ?? "http"; + set => Scheme = value; + } + + string IHttpRequestFeature.Method + { + get => Method; + set => Method = value; + } + + string IHttpRequestFeature.PathBase + { + get => PathBase ?? ""; + set => PathBase = value; + } + + string IHttpRequestFeature.Path + { + get => Path; + set => Path = value; + } + + string IHttpRequestFeature.QueryString + { + get => QueryString; + set => QueryString = value; + } + + string IHttpRequestFeature.RawTarget + { + get => RawTarget; + set => RawTarget = value; + } + + IHeaderDictionary IHttpRequestFeature.Headers + { + get => RequestHeaders; + set => RequestHeaders = value; + } + + Stream IHttpRequestFeature.Body + { + get => RequestBody; + set => RequestBody = value; + } + + int IHttpResponseFeature.StatusCode + { + get => StatusCode; + set => StatusCode = value; + } + + string IHttpResponseFeature.ReasonPhrase + { + get => ReasonPhrase; + set => ReasonPhrase = value; + } + + IHeaderDictionary IHttpResponseFeature.Headers + { + get => ResponseHeaders; + set => ResponseHeaders = value; + } + + Stream IHttpResponseFeature.Body + { + get => ResponseBody; + set => ResponseBody = value; + } + + CancellationToken IHttpRequestLifetimeFeature.RequestAborted + { + get => RequestAborted; + set => RequestAborted = value; + } + + bool IHttpResponseFeature.HasStarted => HasResponseStarted; + + bool IHttpUpgradeFeature.IsUpgradableRequest => false; + + bool IFeatureCollection.IsReadOnly => false; + + int IFeatureCollection.Revision => _featureRevision; + + IPAddress IHttpConnectionFeature.RemoteIpAddress + { + get => RemoteIpAddress; + set => RemoteIpAddress = value; + } + + IPAddress IHttpConnectionFeature.LocalIpAddress + { + get => LocalIpAddress; + set => LocalIpAddress = value; + } + + int IHttpConnectionFeature.RemotePort + { + get => RemotePort; + set => RemotePort = value; + } + + int IHttpConnectionFeature.LocalPort + { + get => LocalPort; + set => LocalPort = value; + } + + string IHttpConnectionFeature.ConnectionId + { + get => ConnectionIdFeature; + set => ConnectionIdFeature = value; + } + + string IHttpRequestIdentifierFeature.TraceIdentifier + { + get => TraceIdentifier; + set => TraceIdentifier = value; + } + + bool IHttpBodyControlFeature.AllowSynchronousIO + { + get => AllowSynchronousIO; + set => AllowSynchronousIO = value; + } + + bool IHttpMaxRequestBodySizeFeature.IsReadOnly => HasStartedConsumingRequestBody; + + long? IHttpMaxRequestBodySizeFeature.MaxRequestBodySize + { + get => MaxRequestBodySize; + set + { + if (HasStartedConsumingRequestBody) + { + throw new InvalidOperationException(CoreStrings.MaxRequestBodySizeCannotBeModifiedAfterRead); + } + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), CoreStrings.NonNegativeNumberOrNullRequired); + } + + MaxRequestBodySize = value; + } + } + + MinDataRate IHttpMinRequestBodyDataRateFeature.MinDataRate + { + get => throw new NotImplementedException(); + set => throw new NotImplementedException(); + } + + MinDataRate IHttpMinResponseDataRateFeature.MinDataRate + { + get => throw new NotImplementedException(); + set => throw new NotImplementedException(); + } + + object IFeatureCollection.this[Type key] + { + get => FastFeatureGet(key); + set => FastFeatureSet(key, value); + } + + TFeature IFeatureCollection.Get() + { + return (TFeature)FastFeatureGet(typeof(TFeature)); + } + + void IFeatureCollection.Set(TFeature instance) + { + FastFeatureSet(typeof(TFeature), instance); + } + + void IHttpResponseFeature.OnStarting(Func callback, object state) + { + OnStarting(callback, state); + } + + void IHttpResponseFeature.OnCompleted(Func callback, object state) + { + OnCompleted(callback, state); + } + + Task IHttpUpgradeFeature.UpgradeAsync() + { + throw new NotImplementedException(); + } + + IEnumerator> IEnumerable>.GetEnumerator() => FastEnumerable().GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => FastEnumerable().GetEnumerator(); + + void IHttpRequestLifetimeFeature.Abort() + { + Abort(error: null); + } + + int IHttp2StreamIdFeature.StreamId => StreamId; + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Stream.Generated.cs b/src/Kestrel.Core/Internal/Http2/Http2Stream.Generated.cs new file mode 100644 index 0000000000..4652b01e41 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Stream.Generated.cs @@ -0,0 +1,378 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public partial class Http2Stream + { + private static readonly Type IHttpRequestFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IHttpRequestFeature); + private static readonly Type IHttpResponseFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IHttpResponseFeature); + private static readonly Type IHttpRequestIdentifierFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IHttpRequestIdentifierFeature); + private static readonly Type IServiceProvidersFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IServiceProvidersFeature); + private static readonly Type IHttpRequestLifetimeFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IHttpRequestLifetimeFeature); + private static readonly Type IHttpConnectionFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IHttpConnectionFeature); + private static readonly Type IHttpAuthenticationFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.Authentication.IHttpAuthenticationFeature); + private static readonly Type IQueryFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IQueryFeature); + private static readonly Type IFormFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IFormFeature); + private static readonly Type IHttpUpgradeFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IHttpUpgradeFeature); + private static readonly Type IResponseCookiesFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IResponseCookiesFeature); + private static readonly Type IItemsFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IItemsFeature); + private static readonly Type ITlsConnectionFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.ITlsConnectionFeature); + private static readonly Type IHttpWebSocketFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IHttpWebSocketFeature); + private static readonly Type ISessionFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.ISessionFeature); + private static readonly Type IHttpMaxRequestBodySizeFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IHttpMaxRequestBodySizeFeature); + private static readonly Type IHttpMinRequestBodyDataRateFeatureType = typeof(global::Microsoft.AspNetCore.Server.Kestrel.Core.Features.IHttpMinRequestBodyDataRateFeature); + private static readonly Type IHttpMinResponseDataRateFeatureType = typeof(global::Microsoft.AspNetCore.Server.Kestrel.Core.Features.IHttpMinResponseDataRateFeature); + private static readonly Type IHttpBodyControlFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IHttpBodyControlFeature); + private static readonly Type IHttpSendFileFeatureType = typeof(global::Microsoft.AspNetCore.Http.Features.IHttpSendFileFeature); + private static readonly Type IHttp2StreamIdFeatureType = typeof(global::Microsoft.AspNetCore.Server.Kestrel.Core.Features.IHttp2StreamIdFeature); + + private object _currentIHttpRequestFeature; + private object _currentIHttpResponseFeature; + private object _currentIHttpRequestIdentifierFeature; + private object _currentIServiceProvidersFeature; + private object _currentIHttpRequestLifetimeFeature; + private object _currentIHttpConnectionFeature; + private object _currentIHttpAuthenticationFeature; + private object _currentIQueryFeature; + private object _currentIFormFeature; + private object _currentIHttpUpgradeFeature; + private object _currentIResponseCookiesFeature; + private object _currentIItemsFeature; + private object _currentITlsConnectionFeature; + private object _currentIHttpWebSocketFeature; + private object _currentISessionFeature; + private object _currentIHttpMaxRequestBodySizeFeature; + private object _currentIHttpMinRequestBodyDataRateFeature; + private object _currentIHttpMinResponseDataRateFeature; + private object _currentIHttpBodyControlFeature; + private object _currentIHttpSendFileFeature; + private object _currentIHttp2StreamIdFeature; + + private void FastReset() + { + _currentIHttpRequestFeature = this; + _currentIHttpResponseFeature = this; + _currentIHttpUpgradeFeature = this; + _currentIHttpRequestIdentifierFeature = this; + _currentIHttpRequestLifetimeFeature = this; + _currentIHttpConnectionFeature = this; + _currentIHttpMaxRequestBodySizeFeature = this; + _currentIHttpMinRequestBodyDataRateFeature = this; + _currentIHttpMinResponseDataRateFeature = this; + _currentIHttpBodyControlFeature = this; + _currentIHttp2StreamIdFeature = this; + + _currentIServiceProvidersFeature = null; + _currentIHttpAuthenticationFeature = null; + _currentIQueryFeature = null; + _currentIFormFeature = null; + _currentIResponseCookiesFeature = null; + _currentIItemsFeature = null; + _currentITlsConnectionFeature = null; + _currentIHttpWebSocketFeature = null; + _currentISessionFeature = null; + _currentIHttpSendFileFeature = null; + } + + internal object FastFeatureGet(Type key) + { + if (key == IHttpRequestFeatureType) + { + return _currentIHttpRequestFeature; + } + if (key == IHttpResponseFeatureType) + { + return _currentIHttpResponseFeature; + } + if (key == IHttpRequestIdentifierFeatureType) + { + return _currentIHttpRequestIdentifierFeature; + } + if (key == IServiceProvidersFeatureType) + { + return _currentIServiceProvidersFeature; + } + if (key == IHttpRequestLifetimeFeatureType) + { + return _currentIHttpRequestLifetimeFeature; + } + if (key == IHttpConnectionFeatureType) + { + return _currentIHttpConnectionFeature; + } + if (key == IHttpAuthenticationFeatureType) + { + return _currentIHttpAuthenticationFeature; + } + if (key == IQueryFeatureType) + { + return _currentIQueryFeature; + } + if (key == IFormFeatureType) + { + return _currentIFormFeature; + } + if (key == IHttpUpgradeFeatureType) + { + return _currentIHttpUpgradeFeature; + } + if (key == IResponseCookiesFeatureType) + { + return _currentIResponseCookiesFeature; + } + if (key == IItemsFeatureType) + { + return _currentIItemsFeature; + } + if (key == ITlsConnectionFeatureType) + { + return _currentITlsConnectionFeature; + } + if (key == IHttpWebSocketFeatureType) + { + return _currentIHttpWebSocketFeature; + } + if (key == ISessionFeatureType) + { + return _currentISessionFeature; + } + if (key == IHttpMaxRequestBodySizeFeatureType) + { + return _currentIHttpMaxRequestBodySizeFeature; + } + if (key == IHttpMinRequestBodyDataRateFeatureType) + { + return _currentIHttpMinRequestBodyDataRateFeature; + } + if (key == IHttpMinResponseDataRateFeatureType) + { + return _currentIHttpMinResponseDataRateFeature; + } + if (key == IHttpBodyControlFeatureType) + { + return _currentIHttpBodyControlFeature; + } + if (key == IHttpSendFileFeatureType) + { + return _currentIHttpSendFileFeature; + } + if (key == IHttp2StreamIdFeatureType) + { + return _currentIHttp2StreamIdFeature; + } + return ExtraFeatureGet(key); + } + + internal void FastFeatureSet(Type key, object feature) + { + _featureRevision++; + + if (key == IHttpRequestFeatureType) + { + _currentIHttpRequestFeature = feature; + return; + } + if (key == IHttpResponseFeatureType) + { + _currentIHttpResponseFeature = feature; + return; + } + if (key == IHttpRequestIdentifierFeatureType) + { + _currentIHttpRequestIdentifierFeature = feature; + return; + } + if (key == IServiceProvidersFeatureType) + { + _currentIServiceProvidersFeature = feature; + return; + } + if (key == IHttpRequestLifetimeFeatureType) + { + _currentIHttpRequestLifetimeFeature = feature; + return; + } + if (key == IHttpConnectionFeatureType) + { + _currentIHttpConnectionFeature = feature; + return; + } + if (key == IHttpAuthenticationFeatureType) + { + _currentIHttpAuthenticationFeature = feature; + return; + } + if (key == IQueryFeatureType) + { + _currentIQueryFeature = feature; + return; + } + if (key == IFormFeatureType) + { + _currentIFormFeature = feature; + return; + } + if (key == IHttpUpgradeFeatureType) + { + _currentIHttpUpgradeFeature = feature; + return; + } + if (key == IResponseCookiesFeatureType) + { + _currentIResponseCookiesFeature = feature; + return; + } + if (key == IItemsFeatureType) + { + _currentIItemsFeature = feature; + return; + } + if (key == ITlsConnectionFeatureType) + { + _currentITlsConnectionFeature = feature; + return; + } + if (key == IHttpWebSocketFeatureType) + { + _currentIHttpWebSocketFeature = feature; + return; + } + if (key == ISessionFeatureType) + { + _currentISessionFeature = feature; + return; + } + if (key == IHttpMaxRequestBodySizeFeatureType) + { + _currentIHttpMaxRequestBodySizeFeature = feature; + return; + } + if (key == IHttpMinRequestBodyDataRateFeatureType) + { + _currentIHttpMinRequestBodyDataRateFeature = feature; + return; + } + if (key == IHttpMinResponseDataRateFeatureType) + { + _currentIHttpMinResponseDataRateFeature = feature; + return; + } + if (key == IHttpBodyControlFeatureType) + { + _currentIHttpBodyControlFeature = feature; + return; + } + if (key == IHttpSendFileFeatureType) + { + _currentIHttpSendFileFeature = feature; + return; + } + if (key == IHttp2StreamIdFeatureType) + { + _currentIHttp2StreamIdFeature = feature; + return; + }; + ExtraFeatureSet(key, feature); + } + + private IEnumerable> FastEnumerable() + { + if (_currentIHttpRequestFeature != null) + { + yield return new KeyValuePair(IHttpRequestFeatureType, _currentIHttpRequestFeature as global::Microsoft.AspNetCore.Http.Features.IHttpRequestFeature); + } + if (_currentIHttpResponseFeature != null) + { + yield return new KeyValuePair(IHttpResponseFeatureType, _currentIHttpResponseFeature as global::Microsoft.AspNetCore.Http.Features.IHttpResponseFeature); + } + if (_currentIHttpRequestIdentifierFeature != null) + { + yield return new KeyValuePair(IHttpRequestIdentifierFeatureType, _currentIHttpRequestIdentifierFeature as global::Microsoft.AspNetCore.Http.Features.IHttpRequestIdentifierFeature); + } + if (_currentIServiceProvidersFeature != null) + { + yield return new KeyValuePair(IServiceProvidersFeatureType, _currentIServiceProvidersFeature as global::Microsoft.AspNetCore.Http.Features.IServiceProvidersFeature); + } + if (_currentIHttpRequestLifetimeFeature != null) + { + yield return new KeyValuePair(IHttpRequestLifetimeFeatureType, _currentIHttpRequestLifetimeFeature as global::Microsoft.AspNetCore.Http.Features.IHttpRequestLifetimeFeature); + } + if (_currentIHttpConnectionFeature != null) + { + yield return new KeyValuePair(IHttpConnectionFeatureType, _currentIHttpConnectionFeature as global::Microsoft.AspNetCore.Http.Features.IHttpConnectionFeature); + } + if (_currentIHttpAuthenticationFeature != null) + { + yield return new KeyValuePair(IHttpAuthenticationFeatureType, _currentIHttpAuthenticationFeature as global::Microsoft.AspNetCore.Http.Features.Authentication.IHttpAuthenticationFeature); + } + if (_currentIQueryFeature != null) + { + yield return new KeyValuePair(IQueryFeatureType, _currentIQueryFeature as global::Microsoft.AspNetCore.Http.Features.IQueryFeature); + } + if (_currentIFormFeature != null) + { + yield return new KeyValuePair(IFormFeatureType, _currentIFormFeature as global::Microsoft.AspNetCore.Http.Features.IFormFeature); + } + if (_currentIHttpUpgradeFeature != null) + { + yield return new KeyValuePair(IHttpUpgradeFeatureType, _currentIHttpUpgradeFeature as global::Microsoft.AspNetCore.Http.Features.IHttpUpgradeFeature); + } + if (_currentIResponseCookiesFeature != null) + { + yield return new KeyValuePair(IResponseCookiesFeatureType, _currentIResponseCookiesFeature as global::Microsoft.AspNetCore.Http.Features.IResponseCookiesFeature); + } + if (_currentIItemsFeature != null) + { + yield return new KeyValuePair(IItemsFeatureType, _currentIItemsFeature as global::Microsoft.AspNetCore.Http.Features.IItemsFeature); + } + if (_currentITlsConnectionFeature != null) + { + yield return new KeyValuePair(ITlsConnectionFeatureType, _currentITlsConnectionFeature as global::Microsoft.AspNetCore.Http.Features.ITlsConnectionFeature); + } + if (_currentIHttpWebSocketFeature != null) + { + yield return new KeyValuePair(IHttpWebSocketFeatureType, _currentIHttpWebSocketFeature as global::Microsoft.AspNetCore.Http.Features.IHttpWebSocketFeature); + } + if (_currentISessionFeature != null) + { + yield return new KeyValuePair(ISessionFeatureType, _currentISessionFeature as global::Microsoft.AspNetCore.Http.Features.ISessionFeature); + } + if (_currentIHttpMaxRequestBodySizeFeature != null) + { + yield return new KeyValuePair(IHttpMaxRequestBodySizeFeatureType, _currentIHttpMaxRequestBodySizeFeature as global::Microsoft.AspNetCore.Http.Features.IHttpMaxRequestBodySizeFeature); + } + if (_currentIHttpMinRequestBodyDataRateFeature != null) + { + yield return new KeyValuePair(IHttpMinRequestBodyDataRateFeatureType, _currentIHttpMinRequestBodyDataRateFeature as global::Microsoft.AspNetCore.Server.Kestrel.Core.Features.IHttpMinRequestBodyDataRateFeature); + } + if (_currentIHttpMinResponseDataRateFeature != null) + { + yield return new KeyValuePair(IHttpMinResponseDataRateFeatureType, _currentIHttpMinResponseDataRateFeature as global::Microsoft.AspNetCore.Server.Kestrel.Core.Features.IHttpMinResponseDataRateFeature); + } + if (_currentIHttpBodyControlFeature != null) + { + yield return new KeyValuePair(IHttpBodyControlFeatureType, _currentIHttpBodyControlFeature as global::Microsoft.AspNetCore.Http.Features.IHttpBodyControlFeature); + } + if (_currentIHttpSendFileFeature != null) + { + yield return new KeyValuePair(IHttpSendFileFeatureType, _currentIHttpSendFileFeature as global::Microsoft.AspNetCore.Http.Features.IHttpSendFileFeature); + } + if (_currentIHttp2StreamIdFeature != null) + { + yield return new KeyValuePair(IHttp2StreamIdFeatureType, _currentIHttp2StreamIdFeature as global::Microsoft.AspNetCore.Server.Kestrel.Core.Features.IHttp2StreamIdFeature); + } + + if (MaybeExtra != null) + { + foreach(var item in MaybeExtra) + { + yield return item; + } + } + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Stream.cs b/src/Kestrel.Core/Internal/Http2/Http2Stream.cs new file mode 100644 index 0000000000..0415cb3449 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Stream.cs @@ -0,0 +1,1031 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.IO.Pipelines; +using System.Linq; +using System.Net; +using System.Text; +using System.Text.Encodings.Web.Utf8; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Primitives; + +// ReSharper disable AccessToModifiedClosure + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public abstract partial class Http2Stream : IFrameControl + { + private const byte ByteAsterisk = (byte)'*'; + private const byte ByteForwardSlash = (byte)'/'; + private const byte BytePercentage = (byte)'%'; + + private static readonly byte[] _bytesServer = Encoding.ASCII.GetBytes("\r\nServer: " + Constants.ServerName); + + private const string EmptyPath = "/"; + private const string Asterisk = "*"; + + private readonly object _onStartingSync = new Object(); + private readonly object _onCompletedSync = new Object(); + + private Http2StreamContext _context; + private Http2Streams _streams; + + protected Stack, object>> _onStarting; + protected Stack, object>> _onCompleted; + + protected int _requestAborted; + private CancellationTokenSource _abortedCts; + private CancellationToken? _manuallySetRequestAbortToken; + + protected RequestProcessingStatus _requestProcessingStatus; + private bool _canHaveBody; + protected Exception _applicationException; + private BadHttpRequestException _requestRejectedException; + + private string _requestId; + + protected long _responseBytesWritten; + + private HttpRequestTarget _requestTargetForm = HttpRequestTarget.Unknown; + private Uri _absoluteRequestTarget; + + public Http2Stream(Http2StreamContext context) + { + _context = context; + HttpStreamControl = this; + ServerOptions = context.ServiceContext.ServerOptions; + RequestBodyPipe = CreateRequestBodyPipe(); + } + + public IFrameControl HttpStreamControl { get; set; } + + public Http2MessageBody MessageBody { get; protected set; } + public IPipe RequestBodyPipe { get; } + + protected string ConnectionId => _context.ConnectionId; + public int StreamId => _context.StreamId; + public ServiceContext ServiceContext => _context.ServiceContext; + + // Hold direct reference to ServerOptions since this is used very often in the request processing path + private KestrelServerOptions ServerOptions { get; } + + public IFeatureCollection ConnectionFeatures { get; set; } + protected IHttp2StreamLifetimeHandler StreamLifetimeHandler => _context.StreamLifetimeHandler; + public IHttp2FrameWriter Output => _context.FrameWriter; + + protected IKestrelTrace Log => ServiceContext.Log; + private DateHeaderValueManager DateHeaderValueManager => ServiceContext.DateHeaderValueManager; + + private IPEndPoint LocalEndPoint => _context.LocalEndPoint; + private IPEndPoint RemoteEndPoint => _context.RemoteEndPoint; + + public string ConnectionIdFeature { get; set; } + public bool HasStartedConsumingRequestBody { get; set; } + public long? MaxRequestBodySize { get; set; } + public bool AllowSynchronousIO { get; set; } + + public bool ExpectBody { get; set; } + + /// + /// The request id. + /// + public string TraceIdentifier + { + set => _requestId = value; + get + { + // don't generate an ID until it is requested + if (_requestId == null) + { + _requestId = StringUtilities.ConcatAsHexSuffix(ConnectionId, ':', (uint)StreamId); + } + return _requestId; + } + } + + public IPAddress RemoteIpAddress { get; set; } + public int RemotePort { get; set; } + public IPAddress LocalIpAddress { get; set; } + public int LocalPort { get; set; } + public string Scheme { get; set; } + public string Method { get; set; } + public string PathBase { get; set; } + public string Path { get; set; } + public string QueryString { get; set; } + public string RawTarget { get; set; } + + public string HttpVersion => "HTTP/2"; + + public IHeaderDictionary RequestHeaders { get; set; } + public Stream RequestBody { get; set; } + + private int _statusCode; + public int StatusCode + { + get => _statusCode; + set + { + if (HasResponseStarted) + { + ThrowResponseAlreadyStartedException(nameof(StatusCode)); + } + + _statusCode = value; + } + } + + private string _reasonPhrase; + + public string ReasonPhrase + { + get => _reasonPhrase; + + set + { + if (HasResponseStarted) + { + ThrowResponseAlreadyStartedException(nameof(ReasonPhrase)); + } + + _reasonPhrase = value; + } + } + + public IHeaderDictionary ResponseHeaders { get; set; } + public Stream ResponseBody { get; set; } + + public CancellationToken RequestAborted + { + get + { + // If a request abort token was previously explicitly set, return it. + if (_manuallySetRequestAbortToken.HasValue) + { + return _manuallySetRequestAbortToken.Value; + } + // Otherwise, get the abort CTS. If we have one, which would mean that someone previously + // asked for the RequestAborted token, simply return its token. If we don't, + // check to see whether we've already aborted, in which case just return an + // already canceled token. Finally, force a source into existence if we still + // don't have one, and return its token. + var cts = _abortedCts; + return + cts != null ? cts.Token : + (Volatile.Read(ref _requestAborted) == 1) ? new CancellationToken(true) : + RequestAbortedSource.Token; + } + set + { + // Set an abort token, overriding one we create internally. This setter and associated + // field exist purely to support IHttpRequestLifetimeFeature.set_RequestAborted. + _manuallySetRequestAbortToken = value; + } + } + + private CancellationTokenSource RequestAbortedSource + { + get + { + // Get the abort token, lazily-initializing it if necessary. + // Make sure it's canceled if an abort request already came in. + + // EnsureInitialized can return null since _abortedCts is reset to null + // after it's already been initialized to a non-null value. + // If EnsureInitialized does return null, this property was accessed between + // requests so it's safe to return an ephemeral CancellationTokenSource. + var cts = LazyInitializer.EnsureInitialized(ref _abortedCts, () => new CancellationTokenSource()) + ?? new CancellationTokenSource(); + + if (Volatile.Read(ref _requestAborted) == 1) + { + cts.Cancel(); + } + return cts; + } + } + + public bool HasResponseStarted => _requestProcessingStatus == RequestProcessingStatus.ResponseStarted; + + protected FrameRequestHeaders FrameRequestHeaders { get; } = new FrameRequestHeaders(); + + protected FrameResponseHeaders FrameResponseHeaders { get; } = new FrameResponseHeaders(); + + public void InitializeStreams(Http2MessageBody messageBody) + { + if (_streams == null) + { + _streams = new Http2Streams(bodyControl: this, httpStreamControl: this); + } + + (RequestBody, ResponseBody) = _streams.Start(messageBody); + } + + public void PauseStreams() => _streams.Pause(); + + public void StopStreams() => _streams.Stop(); + + public void Reset() + { + _onStarting = null; + _onCompleted = null; + + _requestProcessingStatus = RequestProcessingStatus.RequestPending; + _applicationException = null; + + ResetFeatureCollection(); + + HasStartedConsumingRequestBody = false; + MaxRequestBodySize = ServerOptions.Limits.MaxRequestBodySize; + AllowSynchronousIO = ServerOptions.AllowSynchronousIO; + TraceIdentifier = null; + Scheme = null; + Method = null; + PathBase = null; + Path = null; + RawTarget = null; + _requestTargetForm = HttpRequestTarget.Unknown; + _absoluteRequestTarget = null; + QueryString = null; + StatusCode = StatusCodes.Status200OK; + ReasonPhrase = null; + + RemoteIpAddress = RemoteEndPoint?.Address; + RemotePort = RemoteEndPoint?.Port ?? 0; + + LocalIpAddress = LocalEndPoint?.Address; + LocalPort = LocalEndPoint?.Port ?? 0; + ConnectionIdFeature = ConnectionId; + + FrameRequestHeaders.Reset(); + FrameResponseHeaders.Reset(); + RequestHeaders = FrameRequestHeaders; + ResponseHeaders = FrameResponseHeaders; + + if (ConnectionFeatures != null) + { + foreach (var feature in ConnectionFeatures) + { + // Set the scheme to https if there's an ITlsConnectionFeature + if (feature.Key == typeof(ITlsConnectionFeature)) + { + Scheme = "https"; + } + + FastFeatureSet(feature.Key, feature.Value); + } + } + + _manuallySetRequestAbortToken = null; + _abortedCts = null; + + _responseBytesWritten = 0; + } + + private void CancelRequestAbortedToken() + { + try + { + RequestAbortedSource.Cancel(); + _abortedCts = null; + } + catch (Exception ex) + { + Log.ApplicationError(ConnectionId, TraceIdentifier, ex); + } + } + + public void Abort(Exception error) + { + if (Interlocked.Exchange(ref _requestAborted, 1) == 0) + { + _streams?.Abort(error); + + // Potentially calling user code. CancelRequestAbortedToken logs any exceptions. + ServiceContext.ThreadPool.UnsafeRun(state => ((Http2Stream)state).CancelRequestAbortedToken(), this); + } + } + + public abstract Task ProcessRequestAsync(); + + public void OnStarting(Func callback, object state) + { + lock (_onStartingSync) + { + if (HasResponseStarted) + { + ThrowResponseAlreadyStartedException(nameof(OnStarting)); + } + + if (_onStarting == null) + { + _onStarting = new Stack, object>>(); + } + _onStarting.Push(new KeyValuePair, object>(callback, state)); + } + } + + public void OnCompleted(Func callback, object state) + { + lock (_onCompletedSync) + { + if (_onCompleted == null) + { + _onCompleted = new Stack, object>>(); + } + _onCompleted.Push(new KeyValuePair, object>(callback, state)); + } + } + + protected async Task FireOnStarting() + { + Stack, object>> onStarting = null; + lock (_onStartingSync) + { + onStarting = _onStarting; + _onStarting = null; + } + if (onStarting != null) + { + try + { + foreach (var entry in onStarting) + { + await entry.Key.Invoke(entry.Value); + } + } + catch (Exception ex) + { + ReportApplicationError(ex); + } + } + } + + protected async Task FireOnCompleted() + { + Stack, object>> onCompleted = null; + lock (_onCompletedSync) + { + onCompleted = _onCompleted; + _onCompleted = null; + } + if (onCompleted != null) + { + foreach (var entry in onCompleted) + { + try + { + await entry.Key.Invoke(entry.Value); + } + catch (Exception ex) + { + ReportApplicationError(ex); + } + } + } + } + + public async Task FlushAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + await InitializeResponse(0); + await Output.FlushAsync(cancellationToken); + } + + public Task WriteAsync(ArraySegment data, CancellationToken cancellationToken = default(CancellationToken)) + { + if (!HasResponseStarted) + { + return WriteAsyncAwaited(data, cancellationToken); + } + + VerifyAndUpdateWrite(data.Count); + + if (_canHaveBody) + { + CheckLastWrite(); + return Output.WriteDataAsync(StreamId, data, cancellationToken: cancellationToken); + } + else + { + HandleNonBodyResponseWrite(); + return Task.CompletedTask; + } + } + + public async Task WriteAsyncAwaited(ArraySegment data, CancellationToken cancellationToken) + { + await InitializeResponseAwaited(data.Count); + + // WriteAsyncAwaited is only called for the first write to the body. + // Ensure headers are flushed if Write(Chunked)Async isn't called. + if (_canHaveBody) + { + CheckLastWrite(); + await Output.WriteDataAsync(StreamId, data, cancellationToken: cancellationToken); + } + else + { + HandleNonBodyResponseWrite(); + await FlushAsync(cancellationToken); + } + } + + private void VerifyAndUpdateWrite(int count) + { + var responseHeaders = FrameResponseHeaders; + + if (responseHeaders != null && + !responseHeaders.HasTransferEncoding && + responseHeaders.ContentLength.HasValue && + _responseBytesWritten + count > responseHeaders.ContentLength.Value) + { + throw new InvalidOperationException( + CoreStrings.FormatTooManyBytesWritten(_responseBytesWritten + count, responseHeaders.ContentLength.Value)); + } + + _responseBytesWritten += count; + } + + private void CheckLastWrite() + { + var responseHeaders = FrameResponseHeaders; + + // Prevent firing request aborted token if this is the last write, to avoid + // aborting the request if the app is still running when the client receives + // the final bytes of the response and gracefully closes the connection. + // + // Called after VerifyAndUpdateWrite(), so _responseBytesWritten has already been updated. + if (responseHeaders != null && + !responseHeaders.HasTransferEncoding && + responseHeaders.ContentLength.HasValue && + _responseBytesWritten == responseHeaders.ContentLength.Value) + { + _abortedCts = null; + } + } + + protected void VerifyResponseContentLength() + { + var responseHeaders = FrameResponseHeaders; + + if (!HttpMethods.IsHead(Method) && + !responseHeaders.HasTransferEncoding && + responseHeaders.ContentLength.HasValue && + _responseBytesWritten < responseHeaders.ContentLength.Value) + { + // We need to close the connection if any bytes were written since the client + // cannot be certain of how many bytes it will receive. + if (_responseBytesWritten > 0) + { + // TODO: HTTP/2 + } + + ReportApplicationError(new InvalidOperationException( + CoreStrings.FormatTooFewBytesWritten(_responseBytesWritten, responseHeaders.ContentLength.Value))); + } + } + + private static ArraySegment CreateAsciiByteArraySegment(string text) + { + var bytes = Encoding.ASCII.GetBytes(text); + return new ArraySegment(bytes); + } + + public void ProduceContinue() + { + if (HasResponseStarted) + { + return; + } + + if (RequestHeaders.TryGetValue("Expect", out var expect) && + (expect.FirstOrDefault() ?? "").Equals("100-continue", StringComparison.OrdinalIgnoreCase)) + { + Output.Write100ContinueAsync(StreamId).GetAwaiter().GetResult(); + } + } + + public Task InitializeResponse(int firstWriteByteCount) + { + if (HasResponseStarted) + { + return Task.CompletedTask; + } + + if (_onStarting != null) + { + return InitializeResponseAwaited(firstWriteByteCount); + } + + if (_applicationException != null) + { + ThrowResponseAbortedException(); + } + + VerifyAndUpdateWrite(firstWriteByteCount); + + return ProduceStart(appCompleted: false); + } + + private async Task InitializeResponseAwaited(int firstWriteByteCount) + { + await FireOnStarting(); + + if (_applicationException != null) + { + ThrowResponseAbortedException(); + } + + VerifyAndUpdateWrite(firstWriteByteCount); + + await ProduceStart(appCompleted: false); + } + + private Task ProduceStart(bool appCompleted) + { + if (HasResponseStarted) + { + return Task.CompletedTask; + } + + _requestProcessingStatus = RequestProcessingStatus.ResponseStarted; + + return CreateResponseHeader(appCompleted); + } + + protected Task TryProduceInvalidRequestResponse() + { + if (_requestRejectedException != null) + { + return ProduceEnd(); + } + + return Task.CompletedTask; + } + + protected Task ProduceEnd() + { + if (_requestRejectedException != null || _applicationException != null) + { + if (HasResponseStarted) + { + // We can no longer change the response, so we simply close the connection. + return Task.CompletedTask; + } + + // If the request was rejected, the error state has already been set by SetBadRequestState and + // that should take precedence. + if (_requestRejectedException != null) + { + SetErrorResponseException(_requestRejectedException); + } + else + { + // 500 Internal Server Error + SetErrorResponseHeaders(statusCode: StatusCodes.Status500InternalServerError); + } + } + + if (!HasResponseStarted) + { + return ProduceEndAwaited(); + } + + return WriteSuffix(); + } + + private async Task ProduceEndAwaited() + { + await ProduceStart(appCompleted: true); + + // Force flush + await Output.FlushAsync(default(CancellationToken)); + + await WriteSuffix(); + } + + private Task WriteSuffix() + { + if (HttpMethods.IsHead(Method) && _responseBytesWritten > 0) + { + Log.ConnectionHeadResponseBodyWrite(ConnectionId, _responseBytesWritten); + } + + return Output.WriteDataAsync(StreamId, Span.Empty, endStream: true, cancellationToken: default(CancellationToken)); + } + + private Task CreateResponseHeader(bool appCompleted) + { + var responseHeaders = FrameResponseHeaders; + var hasConnection = responseHeaders.HasConnection; + var connectionOptions = FrameHeaders.ParseConnection(responseHeaders.HeaderConnection); + var hasTransferEncoding = responseHeaders.HasTransferEncoding; + var transferCoding = FrameHeaders.GetFinalTransferCoding(responseHeaders.HeaderTransferEncoding); + + // https://tools.ietf.org/html/rfc7230#section-3.3.1 + // If any transfer coding other than + // chunked is applied to a response payload body, the sender MUST either + // apply chunked as the final transfer coding or terminate the message + // by closing the connection. + if (hasTransferEncoding && transferCoding == TransferCoding.Chunked) + { + // TODO: this is an error in HTTP/2 + } + + // Set whether response can have body + _canHaveBody = StatusCanHaveBody(StatusCode) && Method != "HEAD"; + + // Don't set the Content-Length or Transfer-Encoding headers + // automatically for HEAD requests or 204, 205, 304 responses. + if (_canHaveBody) + { + if (appCompleted) + { + // Since the app has completed and we are only now generating + // the headers we can safely set the Content-Length to 0. + responseHeaders.ContentLength = 0; + } + } + else if (hasTransferEncoding) + { + RejectNonBodyTransferEncodingResponse(appCompleted); + } + + responseHeaders.SetReadOnly(); + + if (ServerOptions.AddServerHeader && !responseHeaders.HasServer) + { + responseHeaders.SetRawServer(Constants.ServerName, _bytesServer); + } + + if (!responseHeaders.HasDate) + { + var dateHeaderValues = DateHeaderValueManager.GetDateHeaderValues(); + responseHeaders.SetRawDate(dateHeaderValues.String, dateHeaderValues.Bytes); + } + + return Output.WriteHeadersAsync(StreamId, StatusCode, responseHeaders); + } + + public bool StatusCanHaveBody(int statusCode) + { + // List of status codes taken from Microsoft.Net.Http.Server.Response + return statusCode != StatusCodes.Status204NoContent && + statusCode != StatusCodes.Status205ResetContent && + statusCode != StatusCodes.Status304NotModified; + } + + private void ThrowResponseAlreadyStartedException(string value) + { + throw new InvalidOperationException(CoreStrings.FormatParameterReadOnlyAfterResponseStarted(value)); + } + + private void RejectNonBodyTransferEncodingResponse(bool appCompleted) + { + var ex = new InvalidOperationException(CoreStrings.FormatHeaderNotAllowedOnResponse("Transfer-Encoding", StatusCode)); + if (!appCompleted) + { + // Back out of header creation surface exeception in user code + _requestProcessingStatus = RequestProcessingStatus.AppStarted; + throw ex; + } + else + { + ReportApplicationError(ex); + + // 500 Internal Server Error + SetErrorResponseHeaders(statusCode: StatusCodes.Status500InternalServerError); + } + } + + private void SetErrorResponseException(BadHttpRequestException ex) + { + SetErrorResponseHeaders(ex.StatusCode); + + if (!StringValues.IsNullOrEmpty(ex.AllowedHeader)) + { + FrameResponseHeaders.HeaderAllow = ex.AllowedHeader; + } + } + + private void SetErrorResponseHeaders(int statusCode) + { + Debug.Assert(!HasResponseStarted, $"{nameof(SetErrorResponseHeaders)} called after response had already started."); + + StatusCode = statusCode; + ReasonPhrase = null; + + var responseHeaders = FrameResponseHeaders; + responseHeaders.Reset(); + var dateHeaderValues = DateHeaderValueManager.GetDateHeaderValues(); + + responseHeaders.SetRawDate(dateHeaderValues.String, dateHeaderValues.Bytes); + + responseHeaders.ContentLength = 0; + + if (ServerOptions.AddServerHeader) + { + responseHeaders.SetRawServer(Constants.ServerName, _bytesServer); + } + } + + public void HandleNonBodyResponseWrite() + { + // Writes to HEAD response are ignored and logged at the end of the request + if (Method != "HEAD") + { + // Throw Exception for 204, 205, 304 responses. + throw new InvalidOperationException(CoreStrings.FormatWritingToResponseBodyNotSupported(StatusCode)); + } + } + + private void ThrowResponseAbortedException() + { + throw new ObjectDisposedException(CoreStrings.UnhandledApplicationException, _applicationException); + } + + public void ThrowRequestRejected(RequestRejectionReason reason) + => throw BadHttpRequestException.GetException(reason); + + public void ThrowRequestRejected(RequestRejectionReason reason, string detail) + => throw BadHttpRequestException.GetException(reason, detail); + + private void ThrowRequestTargetRejected(Span target) + => throw GetInvalidRequestTargetException(target); + + private BadHttpRequestException GetInvalidRequestTargetException(Span target) + => BadHttpRequestException.GetException( + RequestRejectionReason.InvalidRequestTarget, + Log.IsEnabled(LogLevel.Information) + ? target.GetAsciiStringEscaped(Constants.MaxExceptionDetailSize) + : string.Empty); + + public void SetBadRequestState(RequestRejectionReason reason) + { + SetBadRequestState(BadHttpRequestException.GetException(reason)); + } + + public void SetBadRequestState(BadHttpRequestException ex) + { + Log.ConnectionBadRequest(ConnectionId, ex); + + if (!HasResponseStarted) + { + SetErrorResponseException(ex); + } + + _requestRejectedException = ex; + } + + protected void ReportApplicationError(Exception ex) + { + if (_applicationException == null) + { + _applicationException = ex; + } + else if (_applicationException is AggregateException) + { + _applicationException = new AggregateException(_applicationException, ex).Flatten(); + } + else + { + _applicationException = new AggregateException(_applicationException, ex); + } + + Log.ApplicationError(ConnectionId, TraceIdentifier, ex); + } + + private void OnOriginFormTarget(HttpMethod method, Http.HttpVersion version, Span target, Span path, Span query, Span customMethod, bool pathEncoded) + { + Debug.Assert(target[0] == ByteForwardSlash, "Should only be called when path starts with /"); + + _requestTargetForm = HttpRequestTarget.OriginForm; + + // URIs are always encoded/escaped to ASCII https://tools.ietf.org/html/rfc3986#page-11 + // Multibyte Internationalized Resource Identifiers (IRIs) are first converted to utf8; + // then encoded/escaped to ASCII https://www.ietf.org/rfc/rfc3987.txt "Mapping of IRIs to URIs" + string requestUrlPath = null; + string rawTarget = null; + + try + { + // Read raw target before mutating memory. + rawTarget = target.GetAsciiStringNonNullCharacters(); + + if (pathEncoded) + { + // URI was encoded, unescape and then parse as UTF-8 + var pathLength = UrlEncoder.Decode(path, path); + + // Removing dot segments must be done after unescaping. From RFC 3986: + // + // URI producing applications should percent-encode data octets that + // correspond to characters in the reserved set unless these characters + // are specifically allowed by the URI scheme to represent data in that + // component. If a reserved character is found in a URI component and + // no delimiting role is known for that character, then it must be + // interpreted as representing the data octet corresponding to that + // character's encoding in US-ASCII. + // + // https://tools.ietf.org/html/rfc3986#section-2.2 + pathLength = PathNormalizer.RemoveDotSegments(path.Slice(0, pathLength)); + + requestUrlPath = GetUtf8String(path.Slice(0, pathLength)); + } + else + { + var pathLength = PathNormalizer.RemoveDotSegments(path); + + if (path.Length == pathLength && query.Length == 0) + { + // If no decoding was required, no dot segments were removed and + // there is no query, the request path is the same as the raw target + requestUrlPath = rawTarget; + } + else + { + requestUrlPath = path.Slice(0, pathLength).GetAsciiStringNonNullCharacters(); + } + } + } + catch (InvalidOperationException) + { + ThrowRequestTargetRejected(target); + } + + QueryString = query.GetAsciiStringNonNullCharacters(); + RawTarget = rawTarget; + Path = requestUrlPath; + } + + private void OnAuthorityFormTarget(HttpMethod method, Span target) + { + _requestTargetForm = HttpRequestTarget.AuthorityForm; + + // This is not complete validation. It is just a quick scan for invalid characters + // but doesn't check that the target fully matches the URI spec. + for (var i = 0; i < target.Length; i++) + { + var ch = target[i]; + if (!UriUtilities.IsValidAuthorityCharacter(ch)) + { + ThrowRequestTargetRejected(target); + } + } + + // The authority-form of request-target is only used for CONNECT + // requests (https://tools.ietf.org/html/rfc7231#section-4.3.6). + if (method != HttpMethod.Connect) + { + ThrowRequestRejected(RequestRejectionReason.ConnectMethodRequired); + } + + // When making a CONNECT request to establish a tunnel through one or + // more proxies, a client MUST send only the target URI's authority + // component (excluding any userinfo and its "@" delimiter) as the + // request-target.For example, + // + // CONNECT www.example.com:80 HTTP/1.1 + // + // Allowed characters in the 'host + port' section of authority. + // See https://tools.ietf.org/html/rfc3986#section-3.2 + RawTarget = target.GetAsciiStringNonNullCharacters(); + Path = string.Empty; + QueryString = string.Empty; + } + + private void OnAsteriskFormTarget(HttpMethod method) + { + _requestTargetForm = HttpRequestTarget.AsteriskForm; + + // The asterisk-form of request-target is only used for a server-wide + // OPTIONS request (https://tools.ietf.org/html/rfc7231#section-4.3.7). + if (method != HttpMethod.Options) + { + ThrowRequestRejected(RequestRejectionReason.OptionsMethodRequired); + } + + RawTarget = Asterisk; + Path = string.Empty; + QueryString = string.Empty; + } + + private void OnAbsoluteFormTarget(Span target, Span query) + { + _requestTargetForm = HttpRequestTarget.AbsoluteForm; + + // absolute-form + // https://tools.ietf.org/html/rfc7230#section-5.3.2 + + // This code should be the edge-case. + + // From the spec: + // a server MUST accept the absolute-form in requests, even though + // HTTP/1.1 clients will only send them in requests to proxies. + + RawTarget = target.GetAsciiStringNonNullCharacters(); + + // Validation of absolute URIs is slow, but clients + // should not be sending this form anyways, so perf optimization + // not high priority + + if (!Uri.TryCreate(RawTarget, UriKind.Absolute, out var uri)) + { + ThrowRequestTargetRejected(target); + } + + _absoluteRequestTarget = uri; + Path = uri.LocalPath; + // don't use uri.Query because we need the unescaped version + QueryString = query.GetAsciiStringNonNullCharacters(); + } + + private unsafe static string GetUtf8String(Span path) + { + // .NET 451 doesn't have pointer overloads for Encoding.GetString so we + // copy to an array + fixed (byte* pointer = &path.DangerousGetPinnableReference()) + { + return Encoding.UTF8.GetString(pointer, path.Length); + } + } + + public void OnHeader(Span name, Span value) + { + // TODO: move validation of header count and size to HPACK decoding + var valueString = value.GetAsciiStringNonNullCharacters(); + + FrameRequestHeaders.Append(name, valueString); + } + + protected void EnsureHostHeaderExists() + { + // https://tools.ietf.org/html/rfc7230#section-5.4 + // A server MUST respond with a 400 (Bad Request) status code to any + // HTTP/1.1 request message that lacks a Host header field and to any + // request message that contains more than one Host header field or a + // Host header field with an invalid field-value. + + var host = FrameRequestHeaders.HeaderHost; + if (host.Count <= 0) + { + ThrowRequestRejected(RequestRejectionReason.MissingHostHeader); + } + else if (host.Count > 1) + { + ThrowRequestRejected(RequestRejectionReason.MultipleHostHeaders); + } + else if (_requestTargetForm == HttpRequestTarget.AuthorityForm) + { + if (!host.Equals(RawTarget)) + { + ThrowRequestRejected(RequestRejectionReason.InvalidHostHeader, host.ToString()); + } + } + else if (_requestTargetForm == HttpRequestTarget.AbsoluteForm) + { + // If the target URI includes an authority component, then a + // client MUST send a field - value for Host that is identical to that + // authority component, excluding any userinfo subcomponent and its "@" + // delimiter. + + // System.Uri doesn't not tell us if the port was in the original string or not. + // When IsDefaultPort = true, we will allow Host: with or without the default port + var authorityAndPort = _absoluteRequestTarget.Authority + ":" + _absoluteRequestTarget.Port; + if ((host != _absoluteRequestTarget.Authority || !_absoluteRequestTarget.IsDefaultPort) + && host != authorityAndPort) + { + ThrowRequestRejected(RequestRejectionReason.InvalidHostHeader, host.ToString()); + } + } + } + + private IPipe CreateRequestBodyPipe() + => _context.PipeFactory.Create(new PipeOptions + { + ReaderScheduler = ServiceContext.ThreadPool, + WriterScheduler = InlineScheduler.Default, + MaximumSizeHigh = 1, + MaximumSizeLow = 1 + }); + + private enum HttpRequestTarget + { + Unknown = -1, + // origin-form is the most common + OriginForm, + AbsoluteForm, + AuthorityForm, + AsteriskForm + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2StreamContext.cs b/src/Kestrel.Core/Internal/Http2/Http2StreamContext.cs new file mode 100644 index 0000000000..af7174c754 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2StreamContext.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO.Pipelines; +using System.Net; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2StreamContext + { + public string ConnectionId { get; set; } + public int StreamId { get; set; } + public ServiceContext ServiceContext { get; set; } + public PipeFactory PipeFactory { get; set; } + public IPEndPoint RemoteEndPoint { get; set; } + public IPEndPoint LocalEndPoint { get; set; } + public IHttp2StreamLifetimeHandler StreamLifetimeHandler { get; set; } + public IHttp2FrameWriter FrameWriter { get; set; } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2StreamOfT.cs b/src/Kestrel.Core/Internal/Http2/Http2StreamOfT.cs new file mode 100644 index 0000000000..79b7493411 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2StreamOfT.cs @@ -0,0 +1,172 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Protocols; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public class Http2Stream : Http2Stream + { + private readonly IHttpApplication _application; + + public Http2Stream(IHttpApplication application, Http2StreamContext context) + : base(context) + { + _application = application; + } + + public override async Task ProcessRequestAsync() + { + try + { + Method = RequestHeaders[":method"]; + Scheme = RequestHeaders[":scheme"]; + + var path = RequestHeaders[":path"].ToString(); + var queryIndex = path.IndexOf('?'); + + Path = queryIndex == -1 ? path : path.Substring(0, queryIndex); + QueryString = queryIndex == -1 ? string.Empty : path.Substring(queryIndex); + + RequestHeaders["Host"] = RequestHeaders[":authority"]; + + // TODO: figure out what the equivalent for HTTP/2 is + // EnsureHostHeaderExists(); + + MessageBody = Http2MessageBody.For(FrameRequestHeaders, this); + + InitializeStreams(MessageBody); + + var context = _application.CreateContext(this); + try + { + try + { + //KestrelEventSource.Log.RequestStart(this); + + await _application.ProcessRequestAsync(context); + + if (Volatile.Read(ref _requestAborted) == 0) + { + VerifyResponseContentLength(); + } + } + catch (Exception ex) + { + ReportApplicationError(ex); + + if (ex is BadHttpRequestException) + { + throw; + } + } + finally + { + //KestrelEventSource.Log.RequestStop(this); + + // Trigger OnStarting if it hasn't been called yet and the app hasn't + // already failed. If an OnStarting callback throws we can go through + // our normal error handling in ProduceEnd. + // https://github.com/aspnet/KestrelHttpServer/issues/43 + if (!HasResponseStarted && _applicationException == null && _onStarting != null) + { + await FireOnStarting(); + } + + PauseStreams(); + + if (_onCompleted != null) + { + await FireOnCompleted(); + } + } + + // If _requestAbort is set, the connection has already been closed. + if (Volatile.Read(ref _requestAborted) == 0) + { + await ProduceEnd(); + } + else if (!HasResponseStarted) + { + // If the request was aborted and no response was sent, there's no + // meaningful status code to log. + StatusCode = 0; + } + } + catch (BadHttpRequestException ex) + { + // Handle BadHttpRequestException thrown during app execution or remaining message body consumption. + // This has to be caught here so StatusCode is set properly before disposing the HttpContext + // (DisposeContext logs StatusCode). + SetBadRequestState(ex); + } + finally + { + _application.DisposeContext(context, _applicationException); + + // StopStreams should be called before the end of the "if (!_requestProcessingStopping)" block + // to ensure InitializeStreams has been called. + StopStreams(); + + if (HasStartedConsumingRequestBody) + { + RequestBodyPipe.Reader.Complete(); + + // Wait for MessageBody.PumpAsync() to call RequestBodyPipe.Writer.Complete(). + await MessageBody.StopAsync(); + + // At this point both the request body pipe reader and writer should be completed. + RequestBodyPipe.Reset(); + } + } + } + catch (BadHttpRequestException ex) + { + // Handle BadHttpRequestException thrown during request line or header parsing. + // SetBadRequestState logs the error. + SetBadRequestState(ex); + } + catch (ConnectionResetException ex) + { + // Don't log ECONNRESET errors made between requests. Browsers like IE will reset connections regularly. + if (_requestProcessingStatus != RequestProcessingStatus.RequestPending) + { + Log.RequestProcessingError(ConnectionId, ex); + } + } + catch (IOException ex) + { + Log.RequestProcessingError(ConnectionId, ex); + } + catch (Exception ex) + { + Log.LogWarning(0, ex, CoreStrings.RequestProcessingEndError); + } + finally + { + try + { + if (Volatile.Read(ref _requestAborted) == 0) + { + await TryProduceInvalidRequestResponse(); + } + } + catch (Exception ex) + { + Log.LogWarning(0, ex, CoreStrings.ConnectionShutdownError); + } + finally + { + StreamLifetimeHandler.OnStreamCompleted(StreamId); + } + } + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/Http2Streams.cs b/src/Kestrel.Core/Internal/Http2/Http2Streams.cs new file mode 100644 index 0000000000..b433f7c22c --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/Http2Streams.cs @@ -0,0 +1,48 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + internal class Http2Streams + { + private readonly FrameResponseStream _response; + private readonly FrameRequestStream _request; + + public Http2Streams(IHttpBodyControlFeature bodyControl, IFrameControl httpStreamControl) + { + _request = new FrameRequestStream(bodyControl); + _response = new FrameResponseStream(bodyControl, httpStreamControl); + } + + public (Stream request, Stream response) Start(Http2MessageBody body) + { + _request.StartAcceptingReads(body); + _response.StartAcceptingWrites(); + + return (_request, _response); + } + + public void Pause() + { + _request.PauseAcceptingReads(); + _response.PauseAcceptingWrites(); + } + + public void Stop() + { + _request.StopAcceptingReads(); + _response.StopAcceptingWrites(); + } + + public void Abort(Exception error) + { + _request.Abort(error); + _response.Abort(); + } + } +} diff --git a/src/Kestrel.Core/Internal/Http2/IHttp2FrameWriter.cs b/src/Kestrel.Core/Internal/Http2/IHttp2FrameWriter.cs new file mode 100644 index 0000000000..c4615dac51 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/IHttp2FrameWriter.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public interface IHttp2FrameWriter + { + void Abort(Exception error); + Task FlushAsync(CancellationToken cancellationToken); + Task Write100ContinueAsync(int streamId); + Task WriteHeadersAsync(int streamId, int statusCode, IHeaderDictionary headers); + Task WriteDataAsync(int streamId, Span data, CancellationToken cancellationToken); + Task WriteDataAsync(int streamId, Span data, bool endStream, CancellationToken cancellationToken); + Task WriteRstStreamAsync(int streamId, Http2ErrorCode errorCode); + Task WriteSettingsAckAsync(); + Task WritePingAsync(Http2PingFrameFlags flags, Span payload); + Task WriteGoAwayAsync(int lastStreamId, Http2ErrorCode errorCode); + } +} diff --git a/src/Kestrel.Core/Internal/Http2/IHttp2StreamLifetimeHandler.cs b/src/Kestrel.Core/Internal/Http2/IHttp2StreamLifetimeHandler.cs new file mode 100644 index 0000000000..fcb9c89637 --- /dev/null +++ b/src/Kestrel.Core/Internal/Http2/IHttp2StreamLifetimeHandler.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 +{ + public interface IHttp2StreamLifetimeHandler + { + void OnStreamCompleted(int streamId); + } +} diff --git a/src/Kestrel.Tls/ClosedStream.cs b/src/Kestrel.Tls/ClosedStream.cs new file mode 100644 index 0000000000..511869b29f --- /dev/null +++ b/src/Kestrel.Tls/ClosedStream.cs @@ -0,0 +1,68 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Tls +{ + internal class ClosedStream : Stream + { + private static readonly Task ZeroResultTask = Task.FromResult(result: 0); + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => false; + + public override long Length + { + get + { + throw new NotSupportedException(); + } + } + + public override long Position + { + get + { + throw new NotSupportedException(); + } + set + { + throw new NotSupportedException(); + } + } + + public override void Flush() + { + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return 0; + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return ZeroResultTask; + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } + } +} diff --git a/src/Kestrel.Tls/Kestrel.Tls.csproj b/src/Kestrel.Tls/Kestrel.Tls.csproj new file mode 100644 index 0000000000..3a7bc3d924 --- /dev/null +++ b/src/Kestrel.Tls/Kestrel.Tls.csproj @@ -0,0 +1,26 @@ + + + + + + Microsoft.AspNetCore.Server.Kestrel.Tls + Microsoft.AspNetCore.Server.Kestrel.Tls + netstandard2.0 + true + aspnetcore;kestrel + CS1591;$(NoWarn) + false + + CurrentRuntime + true + + + + + + + + + + + diff --git a/src/Kestrel.Tls/ListenOptionsTlsExtensions.cs b/src/Kestrel.Tls/ListenOptionsTlsExtensions.cs new file mode 100644 index 0000000000..85011e6227 --- /dev/null +++ b/src/Kestrel.Tls/ListenOptionsTlsExtensions.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Tls; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Hosting +{ + public static class ListenOptionsTlsExtensions + { + public static ListenOptions UseTls(this ListenOptions listenOptions, string certificatePath, string privateKeyPath) + { + return listenOptions.UseTls(new TlsConnectionAdapterOptions + { + CertificatePath = certificatePath, + PrivateKeyPath = privateKeyPath + }); + } + + public static ListenOptions UseTls(this ListenOptions listenOptions, TlsConnectionAdapterOptions tlsOptions) + { + var loggerFactory = listenOptions.KestrelServerOptions.ApplicationServices.GetRequiredService(); + listenOptions.ConnectionAdapters.Add(new TlsConnectionAdapter(tlsOptions, loggerFactory)); + return listenOptions; + } + } +} diff --git a/src/Kestrel.Tls/OpenSsl.cs b/src/Kestrel.Tls/OpenSsl.cs new file mode 100644 index 0000000000..17568e4b9c --- /dev/null +++ b/src/Kestrel.Tls/OpenSsl.cs @@ -0,0 +1,268 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; + +namespace Microsoft.AspNetCore.Server.Kestrel.Tls +{ + public static class OpenSsl + { + public const int OPENSSL_NPN_NEGOTIATED = 1; + public const int SSL_TLSEXT_ERR_OK = 0; + public const int SSL_TLSEXT_ERR_NOACK = 3; + + private const int BIO_C_SET_BUF_MEM_EOF_RETURN = 130; + private const int SSL_CTRL_SET_ECDH_AUTO = 94; + + public static int SSL_library_init() + { + return NativeMethods.SSL_library_init(); + } + + public static void SSL_load_error_strings() + { + NativeMethods.SSL_load_error_strings(); + } + + public static void OpenSSL_add_all_algorithms() + { + NativeMethods.OPENSSL_add_all_algorithms_noconf(); + } + + public static IntPtr TLSv1_2_method() + { + return NativeMethods.TLSv1_2_method(); + } + + public static IntPtr SSL_CTX_new(IntPtr method) + { + return NativeMethods.SSL_CTX_new(method); + } + + public static void SSL_CTX_free(IntPtr ctx) + { + NativeMethods.SSL_CTX_free(ctx); + } + + public static int SSL_CTX_set_ecdh_auto(IntPtr ctx, int onoff) + { + return (int)NativeMethods.SSL_CTX_ctrl(ctx, SSL_CTRL_SET_ECDH_AUTO, onoff, IntPtr.Zero); + } + + public static int SSL_CTX_use_certificate_file(IntPtr ctx, string file, int type) + { + var ptr = Marshal.StringToHGlobalAnsi(file); + var error = NativeMethods.SSL_CTX_use_certificate_file(ctx, ptr, type); + Marshal.FreeHGlobal(ptr); + + return error; + } + + public static int SSL_CTX_use_PrivateKey_file(IntPtr ctx, string file, int type) + { + var ptr = Marshal.StringToHGlobalAnsi(file); + var error = NativeMethods.SSL_CTX_use_PrivateKey_file(ctx, ptr, type); + Marshal.FreeHGlobal(ptr); + + return error; + } + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public unsafe delegate int alpn_select_cb_t(IntPtr ssl, out byte* @out, out byte outlen, byte* @in, uint inlen, IntPtr arg); + + public unsafe static void SSL_CTX_set_alpn_select_cb(IntPtr ctx, alpn_select_cb_t cb, IntPtr arg) + { + NativeMethods.SSL_CTX_set_alpn_select_cb(ctx, cb, arg); + } + + public static unsafe int SSL_select_next_proto(out byte* @out, out byte outlen, byte* server, uint server_len, byte* client, uint client_len) + { + return NativeMethods.SSL_select_next_proto(out @out, out outlen, server, server_len, client, client_len); + } + + public static unsafe void SSL_get0_alpn_selected(IntPtr ssl, out string protocol) + { + NativeMethods.SSL_get0_alpn_selected(ssl, out var data, out var length); + + protocol = data != null + ? Marshal.PtrToStringAnsi((IntPtr)data, length) + : null; + } + + public static IntPtr SSL_new(IntPtr ctx) + { + return NativeMethods.SSL_new(ctx); + } + + public static void SSL_free(IntPtr ssl) + { + NativeMethods.SSL_free(ssl); + } + + public static int SSL_get_error(IntPtr ssl, int ret) + { + return NativeMethods.SSL_get_error(ssl, ret); + } + + public static void SSL_set_accept_state(IntPtr ssl) + { + NativeMethods.SSL_set_accept_state(ssl); + } + + public static int SSL_do_handshake(IntPtr ssl) + { + return NativeMethods.SSL_do_handshake(ssl); + } + + public static unsafe int SSL_read(IntPtr ssl, byte[] buffer, int offset, int count) + { + fixed (byte* ptr = buffer) + { + return NativeMethods.SSL_read(ssl, (IntPtr)(ptr + offset), count); + } + } + + public static unsafe int SSL_write(IntPtr ssl, byte[] buffer, int offset, int count) + { + fixed (byte* ptr = buffer) + { + return NativeMethods.SSL_write(ssl, (IntPtr)(ptr + offset), count); + } + } + + public static void SSL_set_bio(IntPtr ssl, IntPtr rbio, IntPtr wbio) + { + NativeMethods.SSL_set_bio(ssl, rbio, wbio); + } + + public static IntPtr BIO_new(IntPtr type) + { + return NativeMethods.BIO_new(type); + } + + public static unsafe int BIO_read(IntPtr b, byte[] buffer, int offset, int count) + { + fixed (byte* ptr = buffer) + { + return NativeMethods.BIO_read(b, (IntPtr)(ptr + offset), count); + } + } + + public static unsafe int BIO_write(IntPtr b, byte[] buffer, int offset, int count) + { + fixed (byte* ptr = buffer) + { + return NativeMethods.BIO_write(b, (IntPtr)(ptr + offset), count); + } + } + + public static long BIO_ctrl_pending(IntPtr b) + { + return NativeMethods.BIO_ctrl_pending(b); + } + + public static long BIO_set_mem_eof_return(IntPtr b, int v) + { + return NativeMethods.BIO_ctrl(b, BIO_C_SET_BUF_MEM_EOF_RETURN, v, IntPtr.Zero); + } + + public static IntPtr BIO_s_mem() + { + return NativeMethods.BIO_s_mem(); + } + + public static void ERR_load_BIO_strings() + { + NativeMethods.ERR_load_BIO_strings(); + } + + private class NativeMethods + { + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern int SSL_library_init(); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern void SSL_load_error_strings(); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern void OPENSSL_add_all_algorithms_noconf(); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr TLSv1_2_method(); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr SSL_CTX_new(IntPtr method); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr SSL_CTX_free(IntPtr ctx); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern long SSL_CTX_ctrl(IntPtr ctx, int cmd, long larg, IntPtr parg); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern int SSL_CTX_use_certificate_file(IntPtr ctx, IntPtr file, int type); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern int SSL_CTX_use_PrivateKey_file(IntPtr ctx, IntPtr file, int type); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern void SSL_CTX_set_alpn_select_cb(IntPtr ctx, alpn_select_cb_t cb, IntPtr arg); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern unsafe int SSL_select_next_proto(out byte* @out, out byte outlen, byte* server, uint server_len, byte* client, uint client_len); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern unsafe void SSL_get0_alpn_selected(IntPtr ssl, out byte* data, out int len); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr SSL_new(IntPtr ctx); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr SSL_free(IntPtr ssl); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern int SSL_get_error(IntPtr ssl, int ret); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern void SSL_set_accept_state(IntPtr ssl); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern int SSL_do_handshake(IntPtr ssl); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern int SSL_read(IntPtr ssl, IntPtr buf, int len); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern int SSL_write(IntPtr ssl, IntPtr buf, int len); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern void SSL_set_bio(IntPtr ssl, IntPtr rbio, IntPtr wbio); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr BIO_new(IntPtr type); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern int BIO_read(IntPtr b, IntPtr buf, int len); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern int BIO_write(IntPtr b, IntPtr buf, int len); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern long BIO_ctrl(IntPtr bp, int cmd, long larg, IntPtr parg); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern long BIO_ctrl_pending(IntPtr bp); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr BIO_s_mem(); + + [DllImport("libssl", CallingConvention = CallingConvention.Cdecl)] + public static extern void ERR_load_BIO_strings(); + } + } +} diff --git a/src/Kestrel.Tls/Properties/AssemblyInfo.cs b/src/Kestrel.Tls/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..3bb5150c92 --- /dev/null +++ b/src/Kestrel.Tls/Properties/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.Server.Kestrel.FunctionalTests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] diff --git a/src/Kestrel.Tls/README.md b/src/Kestrel.Tls/README.md new file mode 100644 index 0000000000..9012206247 --- /dev/null +++ b/src/Kestrel.Tls/README.md @@ -0,0 +1,5 @@ +# NOT FOR PRODUCTION USE + +The code in this package contains the bare minimum to make Kestrel work with TLS 1.2 with ALPN support. It has not been audited nor hardened in any way. DO NOT USE THIS IN PRODUCTION. + +This package is temporary and will be removed once `SslStream` supports ALPN. diff --git a/src/Kestrel.Tls/TlsApplicationProtocolFeature.cs b/src/Kestrel.Tls/TlsApplicationProtocolFeature.cs new file mode 100644 index 0000000000..b3bcee74cc --- /dev/null +++ b/src/Kestrel.Tls/TlsApplicationProtocolFeature.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Tls +{ + internal class TlsApplicationProtocolFeature : ITlsApplicationProtocolFeature + { + public TlsApplicationProtocolFeature(string applicationProtocol) + { + ApplicationProtocol = applicationProtocol; + } + + public string ApplicationProtocol { get; } + } +} diff --git a/src/Kestrel.Tls/TlsConnectionAdapter.cs b/src/Kestrel.Tls/TlsConnectionAdapter.cs new file mode 100644 index 0000000000..0dae1b701c --- /dev/null +++ b/src/Kestrel.Tls/TlsConnectionAdapter.cs @@ -0,0 +1,109 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.Kestrel.Tls +{ + public class TlsConnectionAdapter : IConnectionAdapter + { + private static readonly ClosedAdaptedConnection _closedAdaptedConnection = new ClosedAdaptedConnection(); + private static readonly HashSet _serverProtocols = new HashSet(new[] { "h2", "http/1.1" }); + + private readonly TlsConnectionAdapterOptions _options; + private readonly ILogger _logger; + + private string _applicationProtocol; + + public TlsConnectionAdapter(TlsConnectionAdapterOptions options) + : this(options, loggerFactory: null) + { + } + + public TlsConnectionAdapter(TlsConnectionAdapterOptions options, ILoggerFactory loggerFactory) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + if (options.CertificatePath == null) + { + throw new ArgumentException("Certificate path must be non-null.", nameof(options)); + } + + if (options.PrivateKeyPath == null) + { + throw new ArgumentException("Private key path must be non-null.", nameof(options)); + } + + _options = options; + _logger = loggerFactory?.CreateLogger(nameof(TlsConnectionAdapter)); + } + + public bool IsHttps => true; + + public Task OnConnectionAsync(ConnectionAdapterContext context) + { + // Don't trust TlsStream not to block. + return Task.Run(() => InnerOnConnectionAsync(context)); + } + + private async Task InnerOnConnectionAsync(ConnectionAdapterContext context) + { + var tlsStream = new TlsStream(context.ConnectionStream, _options.CertificatePath, _options.PrivateKeyPath, _serverProtocols); + + try + { + await tlsStream.DoHandshakeAsync(); + _applicationProtocol = tlsStream.GetNegotiatedApplicationProtocol(); + } + catch (IOException ex) + { + _logger?.LogInformation(1, ex, "Authentication failed."); + tlsStream.Dispose(); + return _closedAdaptedConnection; + } + + // Always set the feature even though the cert might be null + context.Features.Set(new TlsConnectionFeature()); + context.Features.Set(new TlsApplicationProtocolFeature(_applicationProtocol)); + + return new TlsAdaptedConnection(tlsStream); + } + + private class TlsAdaptedConnection : IAdaptedConnection + { + private readonly TlsStream _tlsStream; + + public TlsAdaptedConnection(TlsStream tlsStream) + { + _tlsStream = tlsStream; + } + + public Stream ConnectionStream => _tlsStream; + + public void Dispose() + { + _tlsStream.Dispose(); + } + } + + private class ClosedAdaptedConnection : IAdaptedConnection + { + public Stream ConnectionStream { get; } = new ClosedStream(); + + public void Dispose() + { + } + } + } +} diff --git a/src/Kestrel.Tls/TlsConnectionAdapterOptions.cs b/src/Kestrel.Tls/TlsConnectionAdapterOptions.cs new file mode 100644 index 0000000000..0d49b62b69 --- /dev/null +++ b/src/Kestrel.Tls/TlsConnectionAdapterOptions.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Server.Kestrel.Tls +{ + public class TlsConnectionAdapterOptions + { + public string CertificatePath { get; set; } = string.Empty; + + public string PrivateKeyPath { get; set; } = string.Empty; + } +} diff --git a/src/Kestrel.Tls/TlsConnectionFeature.cs b/src/Kestrel.Tls/TlsConnectionFeature.cs new file mode 100644 index 0000000000..007fbd0f8a --- /dev/null +++ b/src/Kestrel.Tls/TlsConnectionFeature.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Tls +{ + internal class TlsConnectionFeature : ITlsConnectionFeature + { + public X509Certificate2 ClientCertificate { get; set; } + + public Task GetClientCertificateAsync(CancellationToken cancellationToken) + { + return Task.FromResult(ClientCertificate); + } + } +} diff --git a/src/Kestrel.Tls/TlsStream.cs b/src/Kestrel.Tls/TlsStream.cs new file mode 100644 index 0000000000..6cf8c5e8e0 --- /dev/null +++ b/src/Kestrel.Tls/TlsStream.cs @@ -0,0 +1,231 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Tls +{ + public class TlsStream : Stream + { + private static unsafe OpenSsl.alpn_select_cb_t _alpnSelectCallback = AlpnSelectCallback; + + private readonly Stream _innerStream; + private readonly byte[] _protocols; + private readonly GCHandle _protocolsHandle; + + private IntPtr _ctx; + private IntPtr _ssl; + private IntPtr _inputBio; + private IntPtr _outputBio; + + private readonly byte[] _inputBuffer = new byte[1024 * 1024]; + private readonly byte[] _outputBuffer = new byte[1024 * 1024]; + + static TlsStream() + { + OpenSsl.SSL_library_init(); + OpenSsl.SSL_load_error_strings(); + OpenSsl.ERR_load_BIO_strings(); + OpenSsl.OpenSSL_add_all_algorithms(); + } + + public TlsStream(Stream innerStream, string certificatePath, string privateKeyPath, IEnumerable protocols) + { + _innerStream = innerStream; + _protocols = ToWireFormat(protocols); + _protocolsHandle = GCHandle.Alloc(_protocols); + + _ctx = OpenSsl.SSL_CTX_new(OpenSsl.TLSv1_2_method()); + + if (_ctx == IntPtr.Zero) + { + throw new Exception("Unable to create SSL context."); + } + + OpenSsl.SSL_CTX_set_ecdh_auto(_ctx, 1); + + if (OpenSsl.SSL_CTX_use_certificate_file(_ctx, certificatePath, 1) != 1) + { + throw new Exception("Unable to load certificate file."); + } + + if (OpenSsl.SSL_CTX_use_PrivateKey_file(_ctx, privateKeyPath, 1) != 1) + { + throw new Exception("Unable to load private key file."); + } + + OpenSsl.SSL_CTX_set_alpn_select_cb(_ctx, _alpnSelectCallback, GCHandle.ToIntPtr(_protocolsHandle)); + + _ssl = OpenSsl.SSL_new(_ctx); + + _inputBio = OpenSsl.BIO_new(OpenSsl.BIO_s_mem()); + OpenSsl.BIO_set_mem_eof_return(_inputBio, -1); + + _outputBio = OpenSsl.BIO_new(OpenSsl.BIO_s_mem()); + OpenSsl.BIO_set_mem_eof_return(_outputBio, -1); + + OpenSsl.SSL_set_bio(_ssl, _inputBio, _outputBio); + } + + ~TlsStream() + { + if (_ssl != IntPtr.Zero) + { + OpenSsl.SSL_free(_ssl); + } + + if (_ctx != IntPtr.Zero) + { + // This frees the BIOs. + OpenSsl.SSL_CTX_free(_ctx); + } + + if (_protocolsHandle.IsAllocated) + { + _protocolsHandle.Free(); + } + } + + public override bool CanRead => true; + public override bool CanWrite => true; + + public override bool CanSeek => false; + public override long Length => throw new NotSupportedException(); + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + public override void Flush() + { + FlushAsync(default(CancellationToken)).GetAwaiter().GetResult(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + public override async Task FlushAsync(CancellationToken cancellationToken) + { + var pending = OpenSsl.BIO_ctrl_pending(_outputBio); + + while (pending > 0) + { + var count = OpenSsl.BIO_read(_outputBio, _outputBuffer, 0, _outputBuffer.Length); + await _innerStream.WriteAsync(_outputBuffer, 0, count, cancellationToken); + + pending = OpenSsl.BIO_ctrl_pending(_outputBio); + } + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (OpenSsl.BIO_ctrl_pending(_inputBio) == 0) + { + var bytesRead = await _innerStream.ReadAsync(_inputBuffer, 0, _inputBuffer.Length, cancellationToken); + OpenSsl.BIO_write(_inputBio, _inputBuffer, 0, bytesRead); + } + + return OpenSsl.SSL_read(_ssl, buffer, offset, count); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + OpenSsl.SSL_write(_ssl, buffer, offset, count); + + return FlushAsync(cancellationToken); + } + + public async Task DoHandshakeAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + OpenSsl.SSL_set_accept_state(_ssl); + + var count = 0; + + try + { + while ((count = await _innerStream.ReadAsync(_inputBuffer, 0, _inputBuffer.Length, cancellationToken)) > 0) + { + if (count == 0) + { + throw new IOException("TLS handshake failed: the inner stream was closed."); + } + + OpenSsl.BIO_write(_inputBio, _inputBuffer, 0, count); + + var ret = OpenSsl.SSL_do_handshake(_ssl); + + if (ret != 1) + { + var error = OpenSsl.SSL_get_error(_ssl, ret); + + if (error != 2) + { + throw new IOException($"TLS handshake failed: {nameof(OpenSsl.SSL_do_handshake)} error {error}."); + } + } + + await FlushAsync(cancellationToken); + + if (ret == 1) + { + return; + } + } + } + finally + { + _protocolsHandle.Free(); + } + } + + public string GetNegotiatedApplicationProtocol() + { + OpenSsl.SSL_get0_alpn_selected(_ssl, out var protocol); + return protocol; + } + + private static unsafe int AlpnSelectCallback(IntPtr ssl, out byte* @out, out byte outlen, byte* @in, uint inlen, IntPtr arg) + { + var protocols = GCHandle.FromIntPtr(arg); + var server = (byte[])protocols.Target; + + fixed (byte* serverPtr = server) + { + return OpenSsl.SSL_select_next_proto(out @out, out outlen, serverPtr, (uint)server.Length, @in, (uint)inlen) == OpenSsl.OPENSSL_NPN_NEGOTIATED + ? OpenSsl.SSL_TLSEXT_ERR_OK + : OpenSsl.SSL_TLSEXT_ERR_NOACK; + } + } + + private static byte[] ToWireFormat(IEnumerable protocols) + { + var buffer = new byte[protocols.Count() + protocols.Sum(protocol => protocol.Length)]; + + var offset = 0; + foreach (var protocol in protocols) + { + buffer[offset++] = (byte)protocol.Length; + offset += Encoding.ASCII.GetBytes(protocol, 0, protocol.Length, buffer, offset); + } + + return buffer; + } + } +} diff --git a/src/Kestrel.Transport.Libuv/Internal/LibuvConnectionContext.cs b/src/Kestrel.Transport.Libuv/Internal/LibuvConnectionContext.cs index fafb3aad57..e1319c8317 100644 --- a/src/Kestrel.Transport.Libuv/Internal/LibuvConnectionContext.cs +++ b/src/Kestrel.Transport.Libuv/Internal/LibuvConnectionContext.cs @@ -20,4 +20,4 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal public override IScheduler InputWriterScheduler => ListenerContext.Thread; public override IScheduler OutputReaderScheduler => ListenerContext.Thread; } -} \ No newline at end of file +} diff --git a/test/Kestrel.Core.Tests/FrameConnectionTests.cs b/test/Kestrel.Core.Tests/FrameConnectionTests.cs index 0304ce3f7f..1891872223 100644 --- a/test/Kestrel.Core.Tests/FrameConnectionTests.cs +++ b/test/Kestrel.Core.Tests/FrameConnectionTests.cs @@ -532,7 +532,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests public async Task StartRequestProcessingCreatesLogScopeWithConnectionId() { _frameConnection.StartRequestProcessing(new DummyApplication()); - + var scopeObjects = ((TestKestrelTrace)_frameConnectionContext.ServiceContext.Log) .Logger .Scopes diff --git a/test/Kestrel.Core.Tests/HPackEncoderTests.cs b/test/Kestrel.Core.Tests/HPackEncoderTests.cs new file mode 100644 index 0000000000..de07eeec72 --- /dev/null +++ b/test/Kestrel.Core.Tests/HPackEncoderTests.cs @@ -0,0 +1,130 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HPackEncoderTests + { + [Fact] + public void EncodesHeadersInSinglePayloadWhenSpaceAvailable() + { + var encoder = new HPackEncoder(); + + var statusCode = 200; + var headers = new [] + { + new KeyValuePair("date", "Mon, 24 Jul 2017 19:22:30 GMT"), + new KeyValuePair("content-type", "text/html; charset=utf-8"), + new KeyValuePair("server", "Kestrel") + }; + + var expectedPayload = new byte[] + { + 0x88, 0x00, 0x04, 0x64, 0x61, 0x74, 0x65, 0x1d, + 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x34, 0x20, + 0x4a, 0x75, 0x6c, 0x20, 0x32, 0x30, 0x31, 0x37, + 0x20, 0x31, 0x39, 0x3a, 0x32, 0x32, 0x3a, 0x33, + 0x30, 0x20, 0x47, 0x4d, 0x54, 0x00, 0x0c, 0x63, + 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x74, + 0x79, 0x70, 0x65, 0x18, 0x74, 0x65, 0x78, 0x74, + 0x2f, 0x68, 0x74, 0x6d, 0x6c, 0x3b, 0x20, 0x63, + 0x68, 0x61, 0x72, 0x73, 0x65, 0x74, 0x3d, 0x75, + 0x74, 0x66, 0x2d, 0x38, 0x00, 0x06, 0x73, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x07, 0x4b, 0x65, 0x73, + 0x74, 0x72, 0x65, 0x6c + }; + + var payload = new byte[1024]; + Assert.True(encoder.BeginEncode(statusCode, headers, payload, out var length)); + Assert.Equal(expectedPayload.Length, length); + + for (var i = 0; i < length; i++) + { + Assert.True(expectedPayload[i] == payload[i], $"{expectedPayload[i]} != {payload[i]} at {i} (len {length})"); + } + + Assert.Equal(expectedPayload, new ArraySegment(payload, 0, length)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void EncodesHeadersInMultiplePayloadsWhenSpaceNotAvailable(bool exactSize) + { + var encoder = new HPackEncoder(); + + var statusCode = 200; + var headers = new [] + { + new KeyValuePair("date", "Mon, 24 Jul 2017 19:22:30 GMT"), + new KeyValuePair("content-type", "text/html; charset=utf-8"), + new KeyValuePair("server", "Kestrel") + }; + + var expectedStatusCodePayload = new byte[] + { + 0x88 + }; + + var expectedDateHeaderPayload = new byte[] + { + 0x00, 0x04, 0x64, 0x61, 0x74, 0x65, 0x1d, 0x4d, + 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x34, 0x20, 0x4a, + 0x75, 0x6c, 0x20, 0x32, 0x30, 0x31, 0x37, 0x20, + 0x31, 0x39, 0x3a, 0x32, 0x32, 0x3a, 0x33, 0x30, + 0x20, 0x47, 0x4d, 0x54 + }; + + var expectedContentTypeHeaderPayload = new byte[] + { + 0x00, 0x0c, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, + 0x74, 0x2d, 0x74, 0x79, 0x70, 0x65, 0x18, 0x74, + 0x65, 0x78, 0x74, 0x2f, 0x68, 0x74, 0x6d, 0x6c, + 0x3b, 0x20, 0x63, 0x68, 0x61, 0x72, 0x73, 0x65, + 0x74, 0x3d, 0x75, 0x74, 0x66, 0x2d, 0x38 + }; + + var expectedServerHeaderPayload = new byte[] + { + 0x00, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x07, 0x4b, 0x65, 0x73, 0x74, 0x72, 0x65, 0x6c + }; + + Span payload = new byte[1024]; + var offset = 0; + + // When !exactSize, slices are one byte short of fitting the next header + var sliceLength = expectedStatusCodePayload.Length + (exactSize ? 0 : expectedDateHeaderPayload.Length - 1); + Assert.False(encoder.BeginEncode(statusCode, headers, payload.Slice(offset, sliceLength), out var length)); + Assert.Equal(expectedStatusCodePayload.Length, length); + Assert.Equal(expectedStatusCodePayload, payload.Slice(0, length).ToArray()); + + offset += length; + + sliceLength = expectedDateHeaderPayload.Length + (exactSize ? 0 : expectedContentTypeHeaderPayload.Length - 1); + Assert.False(encoder.Encode(payload.Slice(offset, sliceLength), out length)); + Assert.Equal(expectedDateHeaderPayload.Length, length); + Assert.Equal(expectedDateHeaderPayload, payload.Slice(offset, length).ToArray()); + + offset += length; + + sliceLength = expectedContentTypeHeaderPayload.Length + (exactSize ? 0 : expectedServerHeaderPayload.Length - 1); + Assert.False(encoder.Encode(payload.Slice(offset, sliceLength), out length)); + Assert.Equal(expectedContentTypeHeaderPayload.Length, length); + Assert.Equal(expectedContentTypeHeaderPayload, payload.Slice(offset, length).ToArray()); + + offset += length; + + sliceLength = expectedServerHeaderPayload.Length; + Assert.True(encoder.Encode(payload.Slice(offset, sliceLength), out length)); + Assert.Equal(expectedServerHeaderPayload.Length, length); + Assert.Equal(expectedServerHeaderPayload, payload.Slice(offset, length).ToArray()); + } + } +} diff --git a/test/Kestrel.Core.Tests/Http2ConnectionTests.cs b/test/Kestrel.Core.Tests/Http2ConnectionTests.cs new file mode 100644 index 0000000000..0e1e4668fc --- /dev/null +++ b/test/Kestrel.Core.Tests/Http2ConnectionTests.cs @@ -0,0 +1,1520 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Primitives; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class Http2ConnectionTests : IDisposable + { + private static readonly string _largeHeaderA = new string('a', Http2Frame.MinAllowedMaxFrameSize - Http2Frame.HeaderLength - 8); + + private static readonly string _largeHeaderB = new string('b', Http2Frame.MinAllowedMaxFrameSize - Http2Frame.HeaderLength - 8); + + private static readonly IEnumerable> _postRequestHeaders = new [] + { + new KeyValuePair(":method", "POST"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":authority", "127.0.0.1"), + new KeyValuePair(":scheme", "https"), + }; + + private static readonly IEnumerable> _browserRequestHeaders = new [] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":authority", "127.0.0.1"), + new KeyValuePair(":scheme", "https"), + new KeyValuePair("user-agent", "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:54.0) Gecko/20100101 Firefox/54.0"), + new KeyValuePair("accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"), + new KeyValuePair("accept-language", "en-US,en;q=0.5"), + new KeyValuePair("accept-encoding", "gzip, deflate, br"), + new KeyValuePair("upgrade-insecure-requests", "1"), + }; + + private static readonly IEnumerable> _oneContinuationRequestHeaders = new [] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":authority", "127.0.0.1"), + new KeyValuePair(":scheme", "https"), + new KeyValuePair("a", _largeHeaderA) + }; + + private static readonly IEnumerable> _twoContinuationsRequestHeaders = new [] + { + new KeyValuePair(":method", "GET"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":authority", "127.0.0.1"), + new KeyValuePair(":scheme", "https"), + new KeyValuePair("a", _largeHeaderA), + new KeyValuePair("b", _largeHeaderB) + }; + + private static readonly byte[] _helloBytes = Encoding.ASCII.GetBytes("hello"); + private static readonly byte[] _worldBytes = Encoding.ASCII.GetBytes("world"); + private static readonly byte[] _helloWorldBytes = Encoding.ASCII.GetBytes("hello, world"); + private static readonly byte[] _noData = new byte[0]; + + private readonly PipeFactory _pipeFactory = new PipeFactory(); + private readonly IPipe _inputPipe; + private readonly IPipe _outputPipe; + private readonly Http2ConnectionContext _connectionContext; + private readonly Http2Connection _connection; + private readonly HPackEncoder _hpackEncoder = new HPackEncoder(); + private readonly HPackDecoder _hpackDecoder = new HPackDecoder(); + private readonly Http2PeerSettings _clientSettings = new Http2PeerSettings(); + + private readonly ConcurrentDictionary> _runningStreams = new ConcurrentDictionary>(); + private readonly Dictionary _receivedHeaders = new Dictionary(StringComparer.OrdinalIgnoreCase); + private readonly HashSet _abortedStreamIds = new HashSet(); + private readonly object _abortedStreamIdsLock = new object(); + + private readonly RequestDelegate _noopApplication; + private readonly RequestDelegate _readHeadersApplication; + private readonly RequestDelegate _bufferingApplication; + private readonly RequestDelegate _echoApplication; + private readonly RequestDelegate _echoWaitForAbortApplication; + private readonly RequestDelegate _largeHeadersApplication; + private readonly RequestDelegate _waitForAbortApplication; + private readonly RequestDelegate _waitForAbortFlushingApplication; + + private Task _connectionTask; + + public Http2ConnectionTests() + { + _inputPipe = _pipeFactory.Create(); + _outputPipe = _pipeFactory.Create(); + + _noopApplication = context => Task.CompletedTask; + + _readHeadersApplication = context => + { + foreach (var header in context.Request.Headers) + { + _receivedHeaders[header.Key] = header.Value.ToString(); + } + + return Task.CompletedTask; + }; + + _bufferingApplication = async context => + { + var data = new List(); + var buffer = new byte[1024]; + var received = 0; + + while ((received = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length)) > 0) + { + data.AddRange(new ArraySegment(buffer, 0, received)); + } + + await context.Response.Body.WriteAsync(data.ToArray(), 0, data.Count); + }; + + _echoApplication = async context => + { + var buffer = new byte[Http2Frame.MinAllowedMaxFrameSize]; + var received = 0; + + while ((received = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length)) > 0) + { + await context.Response.Body.WriteAsync(buffer, 0, received); + } + }; + + _echoWaitForAbortApplication = async context => + { + var buffer = new byte[Http2Frame.MinAllowedMaxFrameSize]; + var received = 0; + + while ((received = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length)) > 0) + { + await context.Response.Body.WriteAsync(buffer, 0, received); + } + + var sem = new SemaphoreSlim(0); + + context.RequestAborted.Register(() => + { + sem.Release(); + }); + + await sem.WaitAsync().TimeoutAfter(TimeSpan.FromSeconds(10)); + }; + + _largeHeadersApplication = context => + { + context.Response.Headers["a"] = _largeHeaderA; + context.Response.Headers["b"] = _largeHeaderB; + + return Task.CompletedTask; + }; + + _waitForAbortApplication = async context => + { + var streamIdFeature = context.Features.Get(); + var sem = new SemaphoreSlim(0); + + context.RequestAborted.Register(() => + { + lock (_abortedStreamIdsLock) + { + _abortedStreamIds.Add(streamIdFeature.StreamId); + } + + sem.Release(); + }); + + await sem.WaitAsync().TimeoutAfter(TimeSpan.FromSeconds(10)); + + _runningStreams[streamIdFeature.StreamId].TrySetResult(null); + }; + + _waitForAbortFlushingApplication = async context => + { + var streamIdFeature = context.Features.Get(); + var sem = new SemaphoreSlim(0); + + context.RequestAborted.Register(() => + { + lock (_abortedStreamIdsLock) + { + _abortedStreamIds.Add(streamIdFeature.StreamId); + } + + sem.Release(); + }); + + await sem.WaitAsync().TimeoutAfter(TimeSpan.FromSeconds(10)); + + await context.Response.Body.FlushAsync(); + + _runningStreams[streamIdFeature.StreamId].TrySetResult(null); + }; + + _connectionContext = new Http2ConnectionContext + { + ServiceContext = new TestServiceContext(), + PipeFactory = _pipeFactory, + Input = _inputPipe.Reader, + Output = _outputPipe + }; + _connection = new Http2Connection(_connectionContext); + } + + public void Dispose() + { + _pipeFactory.Dispose(); + } + + [Fact] + public async Task DATA_Received_ReadByStream() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await SendDataAsync(1, _helloWorldBytes, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + var dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 12, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + Assert.Equal(dataFrame.DataPayload, _helloWorldBytes); + } + + [Fact] + public async Task DATA_Received_Multiple_ReadByStream() + { + await InitializeConnectionAsync(_bufferingApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + + for (var i = 0; i < _helloWorldBytes.Length; i++) + { + await SendDataAsync(1, new ArraySegment(_helloWorldBytes, i, 1), endStream: false); + } + + await SendDataAsync(1, _noData, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + var dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 12, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + Assert.Equal(dataFrame.DataPayload, _helloWorldBytes); + } + + [Fact] + public async Task DATA_Received_Multiplexed_ReadByStreams() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await StartStreamAsync(3, _browserRequestHeaders, endStream: false); + + await SendDataAsync(1, _helloBytes, endStream: false); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + var stream1DataFrame1 = await ExpectAsync(Http2FrameType.DATA, + withLength: 5, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + + await SendDataAsync(3, _helloBytes, endStream: false); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 3); + var stream3DataFrame1 = await ExpectAsync(Http2FrameType.DATA, + withLength: 5, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 3); + + await SendDataAsync(3, _worldBytes, endStream: false); + + var stream3DataFrame2 = await ExpectAsync(Http2FrameType.DATA, + withLength: 5, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 3); + + await SendDataAsync(1, _worldBytes, endStream: false); + + var stream1DataFrame2 = await ExpectAsync(Http2FrameType.DATA, + withLength: 5, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + + await SendDataAsync(1, _noData, endStream: true); + + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await SendDataAsync(3, _noData, endStream: true); + + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 3); + + await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false); + + Assert.Equal(stream1DataFrame1.DataPayload, _helloBytes); + Assert.Equal(stream1DataFrame2.DataPayload, _worldBytes); + Assert.Equal(stream3DataFrame1.DataPayload, _helloBytes); + Assert.Equal(stream3DataFrame2.DataPayload, _worldBytes); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(255)] + public async Task DATA_Received_WithPadding_ReadByStream(byte padLength) + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await SendDataWithPaddingAsync(1, _helloWorldBytes, padLength, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + var dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 12, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + Assert.Equal(dataFrame.DataPayload, _helloWorldBytes); + } + + [Fact] + public async Task DATA_Received_StreamIdZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendDataAsync(0, _noData, endStream: false); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task DATA_Received_PaddingEqualToFramePayloadLength_ConnectionError() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await SendInvalidDataFrameAsync(1, frameLength: 5, padLength: 5); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 1, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: true); + } + + [Fact] + public async Task DATA_Received_PaddingGreaterThanFramePayloadLength_ConnectionError() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await SendInvalidDataFrameAsync(1, frameLength: 5, padLength: 6); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 1, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: true); + } + + [Fact] + public async Task DATA_Received_FrameLengthZeroPaddingZero_ConnectionError() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await SendInvalidDataFrameAsync(1, frameLength: 0, padLength: 0); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 1, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: true); + } + + [Fact] + public async Task DATA_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendDataAsync(1, _helloWorldBytes, endStream: true); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task DATA_Received_StreamIdle_StreamError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendDataAsync(1, _helloWorldBytes, endStream: false); + + await WaitForStreamErrorAsync(expectedStreamId: 1, expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, ignoreNonRstStreamFrames: false); + + await StopConnectionAsync(expectedLastStreamId: 0, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task DATA_Received_StreamHalfClosedRemote_StreamError() + { + await InitializeConnectionAsync(_echoWaitForAbortApplication); + + await StartStreamAsync(1, _postRequestHeaders, endStream: false); + await SendDataAsync(1, _helloBytes, endStream: true); + await SendDataAsync(1, _worldBytes, endStream: true); + + await WaitForStreamErrorAsync(expectedStreamId: 1, expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, ignoreNonRstStreamFrames: true); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: true); + } + + [Fact] + public async Task DATA_Received_StreamClosed_StreamError() + { + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(1, _postRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await SendDataAsync(1, _helloWorldBytes, endStream: false); + + await WaitForStreamErrorAsync(expectedStreamId: 1, expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, ignoreNonRstStreamFrames: false); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task DATA_Received_StreamClosedImplicitly_StreamError() + { + // http://httpwg.org/specs/rfc7540.html#rfc.section.5.1.1 + // + // The first use of a new stream identifier implicitly closes all streams in the "idle" state that + // might have been initiated by that peer with a lower-valued stream identifier. For example, if a + // client sends a HEADERS frame on stream 7 without ever sending a frame on stream 5, then stream 5 + // transitions to the "closed" state when the first frame for stream 7 is sent or received. + + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(3, _browserRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 3); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 3); + + await SendDataAsync(1, _helloWorldBytes, endStream: true); + + await WaitForStreamErrorAsync(expectedStreamId: 1, expectedErrorCode: Http2ErrorCode.STREAM_CLOSED, ignoreNonRstStreamFrames: false); + + await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task HEADERS_Received_Decoded() + { + await InitializeConnectionAsync(_readHeadersApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + VerifyDecodedRequestHeaders(_browserRequestHeaders); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(255)] + public async Task HEADERS_Received_WithPadding_Decoded(byte padLength) + { + await InitializeConnectionAsync(_readHeadersApplication); + + await SendHeadersWithPaddingAsync(1, _browserRequestHeaders, padLength, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + VerifyDecodedRequestHeaders(_browserRequestHeaders); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task HEADERS_Received_WithPriority_Decoded() + { + await InitializeConnectionAsync(_readHeadersApplication); + + await SendHeadersWithPriorityAsync(1, _browserRequestHeaders, priority: 42, streamDependency: 0, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + VerifyDecodedRequestHeaders(_browserRequestHeaders); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(255)] + public async Task HEADERS_Received_WithPriorityAndPadding_Decoded(byte padLength) + { + await InitializeConnectionAsync(_readHeadersApplication); + + await SendHeadersWithPaddingAndPriorityAsync(1, _browserRequestHeaders, padLength, priority: 42, streamDependency: 0, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + VerifyDecodedRequestHeaders(_browserRequestHeaders); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task HEADERS_Received_StreamIdZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await StartStreamAsync(0, _browserRequestHeaders, endStream: true); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(255)] + public async Task HEADERS_Received_PaddingEqualToFramePayloadLength_ConnectionError(byte padLength) + { + await InitializeConnectionAsync(_noopApplication); + + await SendInvalidHeadersFrameAsync(1, frameLength: padLength, padLength: padLength); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(0, 1)] + [InlineData(1, 2)] + [InlineData(254, 255)] + public async Task HEADERS_Received_PaddingGreaterThanFramePayloadLength_ConnectionError(int frameLength, byte padLength) + { + await InitializeConnectionAsync(_noopApplication); + + await SendInvalidHeadersFrameAsync(1, frameLength, padLength); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task HEADERS_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendHeadersAsync(3, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task PRIORITY_Received_StreamIdZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendPriorityAsync(0); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(4)] + [InlineData(6)] + public async Task PRIORITY_Received_LengthNotFive_ConnectionError(int length) + { + await InitializeConnectionAsync(_noopApplication); + + await SendInvalidPriorityFrameAsync(1, length); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.FRAME_SIZE_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task PRIORITY_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendPriorityAsync(1); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task RST_STREAM_Received_AbortsStream() + { + await InitializeConnectionAsync(_waitForAbortApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + await SendRstStreamAsync(1); + + // No data is received from the stream since it was aborted before writing anything + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + await WaitForAllStreamsAsync(); + Assert.Contains(1, _abortedStreamIds); + } + + [Fact] + public async Task RST_STREAM_Received_AbortsStream_FlushedDataIsSent() + { + await InitializeConnectionAsync(_waitForAbortFlushingApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + await SendRstStreamAsync(1); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + // No END_STREAM DATA frame is received since the stream was aborted + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + Assert.Contains(1, _abortedStreamIds); + } + + [Fact] + public async Task RST_STREAM_Received_StreamIdZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendRstStreamAsync(0); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(3)] + [InlineData(5)] + public async Task RST_STREAM_Received_LengthNotFour_ConnectionError(int length) + { + await InitializeConnectionAsync(_noopApplication); + + // Start stream 1 so it's legal to send it RST_STREAM frames + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + await SendInvalidRstStreamFrameAsync(1, length); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 1, expectedErrorCode: Http2ErrorCode.FRAME_SIZE_ERROR, ignoreNonGoAwayFrames: true); + } + + [Fact] + public async Task RST_STREAM_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendRstStreamAsync(1); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task SETTINGS_Received_Sends_ACK() + { + await InitializeConnectionAsync(_noopApplication); + + await StopConnectionAsync(expectedLastStreamId: 0, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task SETTINGS_Received_StreamIdZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendSettingsWithInvalidStreamIdAsync(1); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(Http2SettingsParameter.SETTINGS_ENABLE_PUSH, 2, Http2ErrorCode.PROTOCOL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_ENABLE_PUSH, uint.MaxValue, Http2ErrorCode.PROTOCOL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_INITIAL_WINDOW_SIZE, (uint)int.MaxValue + 1, Http2ErrorCode.FLOW_CONTROL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_INITIAL_WINDOW_SIZE, uint.MaxValue, Http2ErrorCode.FLOW_CONTROL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, 0, Http2ErrorCode.PROTOCOL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, 1, Http2ErrorCode.PROTOCOL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, 16 * 1024 - 1, Http2ErrorCode.PROTOCOL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, 16 * 1024 * 1024, Http2ErrorCode.PROTOCOL_ERROR)] + [InlineData(Http2SettingsParameter.SETTINGS_MAX_FRAME_SIZE, uint.MaxValue, Http2ErrorCode.PROTOCOL_ERROR)] + public async Task SETTINGS_Received_InvalidParameterValue_ConnectionError(Http2SettingsParameter parameter, uint value, Http2ErrorCode expectedErrorCode) + { + await InitializeConnectionAsync(_noopApplication); + + await SendSettingsWithInvalidParameterValueAsync(parameter, value); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: expectedErrorCode, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task SETTINGS_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendSettingsAsync(); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(1)] + [InlineData(16 * 1024 - 9)] // Min. max. frame size minus header length + public async Task SETTINGS_Received_WithACK_LengthNotZero_ConnectionError(int length) + { + await InitializeConnectionAsync(_noopApplication); + + await SendSettingsAckWithInvalidLengthAsync(length); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.FRAME_SIZE_ERROR, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(1)] + [InlineData(5)] + [InlineData(7)] + [InlineData(34)] + [InlineData(37)] + public async Task SETTINGS_Received_LengthNotMultipleOfSix_ConnectionError(int length) + { + await InitializeConnectionAsync(_noopApplication); + + await SendSettingsWithInvalidLengthAsync(length); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.FRAME_SIZE_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task PING_Received_Sends_ACK() + { + await InitializeConnectionAsync(_noopApplication); + + await SendPingAsync(); + await ExpectAsync(Http2FrameType.PING, + withLength: 8, + withFlags: (byte)Http2PingFrameFlags.ACK, + withStreamId: 0); + + await StopConnectionAsync(expectedLastStreamId: 0, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task PING_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendPingAsync(); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(7)] + [InlineData(9)] + public async Task PING_Received_LengthNotEight_ConnectionError(int length) + { + await InitializeConnectionAsync(_noopApplication); + + await SendPingWithInvalidLengthAsync(length); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.FRAME_SIZE_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task GOAWAY_Received_ConnectionStops() + { + await InitializeConnectionAsync(_noopApplication); + + await SendGoAwayAsync(); + + await WaitForConnectionStopAsync(expectedLastStreamId: 0, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task GOAWAY_Received_AbortsAllStreams() + { + await InitializeConnectionAsync(_waitForAbortApplication); + + // Start some streams + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + await StartStreamAsync(3, _browserRequestHeaders, endStream: true); + await StartStreamAsync(5, _browserRequestHeaders, endStream: true); + + await SendGoAwayAsync(); + + await WaitForConnectionStopAsync(expectedLastStreamId: 5, ignoreNonGoAwayFrames: true); + + await WaitForAllStreamsAsync(); + Assert.Contains(1, _abortedStreamIds); + Assert.Contains(3, _abortedStreamIds); + Assert.Contains(5, _abortedStreamIds); + } + + [Fact] + public async Task GOAWAY_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendGoAwayAsync(); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task WINDOW_UPDATE_Received_InterleavedWithHeaders_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _browserRequestHeaders); + await SendWindowUpdateAsync(1, sizeIncrement: 42); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Theory] + [InlineData(0, 3)] + [InlineData(0, 5)] + [InlineData(1, 3)] + [InlineData(1, 5)] + public async Task WINDOW_UPDATE_Received_LengthNotFour_ConnectionError(int streamId, int length) + { + await InitializeConnectionAsync(_noopApplication); + + await SendInvalidWindowUpdateAsync(streamId, sizeIncrement: 42, length: length); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.FRAME_SIZE_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task WINDOW_UPDATE_Received_OnConnection_SizeIncrementZero_ConnectionError() + { + await InitializeConnectionAsync(_noopApplication); + + await SendWindowUpdateAsync(0, sizeIncrement: 0); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task WINDOW_UPDATE_Received_OnStream_SizeIncrementZero_StreamError() + { + await InitializeConnectionAsync(_waitForAbortApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + await SendWindowUpdateAsync(1, sizeIncrement: 0); + + await WaitForStreamErrorAsync(expectedStreamId: 1, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonRstStreamFrames: true); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: true); + } + + [Fact] + public async Task CONTINUATION_Received_Decoded() + { + await InitializeConnectionAsync(_readHeadersApplication); + + await StartStreamAsync(1, _twoContinuationsRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2HeadersFrameFlags.END_STREAM, + withStreamId: 1); + + VerifyDecodedRequestHeaders(_twoContinuationsRequestHeaders); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task CONTINUATION_Received_StreamIdMismatch_ConnectionError() + { + await InitializeConnectionAsync(_readHeadersApplication); + + await SendHeadersAsync(1, Http2HeadersFrameFlags.NONE, _oneContinuationRequestHeaders); + await SendContinuationAsync(3, Http2ContinuationFrameFlags.END_HEADERS); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 0, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task CONTINUATION_Sent_WhenHeadersLargerThanFrameLength() + { + await InitializeConnectionAsync(_largeHeadersApplication); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.NONE, + withStreamId: 1); + var continuationFrame1 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16373, + withFlags: (byte)Http2ContinuationFrameFlags.NONE, + withStreamId: 1); + var continuationFrame2 = await ExpectAsync(Http2FrameType.CONTINUATION, + withLength: 16373, + withFlags: (byte)Http2ContinuationFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + var responseHeaders = new FrameResponseHeaders(); + _hpackDecoder.Decode(headersFrame.HeadersPayload, responseHeaders); + _hpackDecoder.Decode(continuationFrame1.HeadersPayload, responseHeaders); + _hpackDecoder.Decode(continuationFrame2.HeadersPayload, responseHeaders); + + var responseHeadersDictionary = (IDictionary)responseHeaders; + Assert.Equal(5, responseHeadersDictionary.Count); + Assert.Contains("date", responseHeadersDictionary.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", responseHeadersDictionary[":status"]); + Assert.Equal("0", responseHeadersDictionary["content-length"]); + Assert.Equal(_largeHeaderA, responseHeadersDictionary["a"]); + Assert.Equal(_largeHeaderB, responseHeadersDictionary["b"]); + } + + [Fact] + public async Task ConnectionError_AbortsAllStreams() + { + await InitializeConnectionAsync(_waitForAbortApplication); + + // Start some streams + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + await StartStreamAsync(3, _browserRequestHeaders, endStream: true); + await StartStreamAsync(5, _browserRequestHeaders, endStream: true); + + // Cause a connection error by sending an invalid frame + await SendDataAsync(0, _noData, endStream: false); + + await WaitForConnectionErrorAsync(expectedLastStreamId: 5, expectedErrorCode: Http2ErrorCode.PROTOCOL_ERROR, ignoreNonGoAwayFrames: false); + + await WaitForAllStreamsAsync(); + Assert.Contains(1, _abortedStreamIds); + Assert.Contains(3, _abortedStreamIds); + Assert.Contains(5, _abortedStreamIds); + } + + private async Task InitializeConnectionAsync(RequestDelegate application) + { + _connectionTask = _connection.ProcessAsync(new DummyApplication(application)); + + await SendPreambleAsync().ConfigureAwait(false); + await SendSettingsAsync(); + + await ExpectAsync(Http2FrameType.SETTINGS, + withLength: 0, + withFlags: 0, + withStreamId: 0); + + await ExpectAsync(Http2FrameType.SETTINGS, + withLength: 0, + withFlags: (byte)Http2SettingsFrameFlags.ACK, + withStreamId: 0); + } + + private async Task StartStreamAsync(int streamId, IEnumerable> headers, bool endStream) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _runningStreams[streamId] = tcs; + + var frame = new Http2Frame(); + frame.PrepareHeaders(Http2HeadersFrameFlags.NONE, streamId); + var done = _hpackEncoder.BeginEncode(headers, frame.HeadersPayload, out var length); + frame.Length = length; + + if (done) + { + frame.HeadersFlags = Http2HeadersFrameFlags.END_HEADERS; + } + + if (endStream) + { + frame.HeadersFlags |= Http2HeadersFrameFlags.END_STREAM; + } + + await SendAsync(frame.Raw); + + while (!done) + { + frame.PrepareContinuation(Http2ContinuationFrameFlags.NONE, streamId); + done = _hpackEncoder.Encode(frame.HeadersPayload, out length); + frame.Length = length; + + if (done) + { + frame.ContinuationFlags = Http2ContinuationFrameFlags.END_HEADERS; + } + + await SendAsync(frame.Raw); + } + } + + private async Task SendHeadersWithPaddingAsync(int streamId, IEnumerable> headers, byte padLength, bool endStream) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _runningStreams[streamId] = tcs; + + var frame = new Http2Frame(); + + frame.PrepareHeaders(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.PADDED, streamId); + frame.HeadersPadLength = padLength; + + _hpackEncoder.BeginEncode(headers, frame.HeadersPayload, out var length); + + frame.Length = 1 + length + padLength; + frame.Payload.Slice(1 + length).Fill(0); + + if (endStream) + { + frame.HeadersFlags |= Http2HeadersFrameFlags.END_STREAM; + } + + await SendAsync(frame.Raw); + } + + private async Task SendHeadersWithPriorityAsync(int streamId, IEnumerable> headers, byte priority, int streamDependency, bool endStream) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _runningStreams[streamId] = tcs; + + var frame = new Http2Frame(); + frame.PrepareHeaders(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.PRIORITY, streamId); + frame.HeadersPriority = priority; + frame.HeadersStreamDependency = streamDependency; + + _hpackEncoder.BeginEncode(headers, frame.HeadersPayload, out var length); + + frame.Length = 5 + length; + + if (endStream) + { + frame.HeadersFlags |= Http2HeadersFrameFlags.END_STREAM; + } + + await SendAsync(frame.Raw); + } + + private async Task SendHeadersWithPaddingAndPriorityAsync(int streamId, IEnumerable> headers, byte padLength, byte priority, int streamDependency, bool endStream) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _runningStreams[streamId] = tcs; + + var frame = new Http2Frame(); + frame.PrepareHeaders(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.PADDED | Http2HeadersFrameFlags.PRIORITY, streamId); + frame.HeadersPadLength = padLength; + frame.HeadersPriority = priority; + frame.HeadersStreamDependency = streamDependency; + + _hpackEncoder.BeginEncode(headers, frame.HeadersPayload, out var length); + + frame.Length = 6 + length + padLength; + frame.Payload.Slice(6 + length).Fill(0); + + if (endStream) + { + frame.HeadersFlags |= Http2HeadersFrameFlags.END_STREAM; + } + + await SendAsync(frame.Raw); + } + + private Task SendStreamDataAsync(int streamId, Span data) + { + var tasks = new List(); + var frame = new Http2Frame(); + + frame.PrepareData(streamId); + + while (data.Length > frame.Length) + { + data.Slice(0, frame.Length).CopyTo(frame.Payload); + data = data.Slice(frame.Length); + tasks.Add(SendAsync(frame.Raw)); + } + + frame.Length = data.Length; + frame.DataFlags = Http2DataFrameFlags.END_STREAM; + data.CopyTo(frame.Payload); + tasks.Add(SendAsync(frame.Raw)); + + return Task.WhenAll(tasks); + } + + private Task WaitForAllStreamsAsync() + { + return Task.WhenAll(_runningStreams.Values.Select(tcs => tcs.Task)).TimeoutAfter(TimeSpan.FromSeconds(30)); + } + + private async Task SendAsync(ArraySegment span) + { + var writableBuffer = _inputPipe.Writer.Alloc(1); + writableBuffer.Write(span); + await writableBuffer.FlushAsync(); + } + + private Task SendPreambleAsync() => SendAsync(new ArraySegment(Http2Connection.ClientPreface)); + + private Task SendSettingsAsync() + { + var frame = new Http2Frame(); + frame.PrepareSettings(Http2SettingsFrameFlags.NONE, _clientSettings); + return SendAsync(frame.Raw); + } + + private Task SendSettingsAckWithInvalidLengthAsync(int length) + { + var frame = new Http2Frame(); + frame.PrepareSettings(Http2SettingsFrameFlags.ACK); + frame.Length = length; + return SendAsync(frame.Raw); + } + + private Task SendSettingsWithInvalidStreamIdAsync(int streamId) + { + var frame = new Http2Frame(); + frame.PrepareSettings(Http2SettingsFrameFlags.NONE, _clientSettings); + frame.StreamId = streamId; + return SendAsync(frame.Raw); + } + + private Task SendSettingsWithInvalidLengthAsync(int length) + { + var frame = new Http2Frame(); + frame.PrepareSettings(Http2SettingsFrameFlags.NONE, _clientSettings); + frame.Length = length; + return SendAsync(frame.Raw); + } + + private Task SendSettingsWithInvalidParameterValueAsync(Http2SettingsParameter parameter, uint value) + { + var frame = new Http2Frame(); + frame.PrepareSettings(Http2SettingsFrameFlags.NONE); + frame.Length = 6; + + frame.Payload[0] = (byte)((ushort)parameter >> 8); + frame.Payload[1] = (byte)(ushort)parameter; + frame.Payload[2] = (byte)(value >> 24); + frame.Payload[3] = (byte)(value >> 16); + frame.Payload[4] = (byte)(value >> 8); + frame.Payload[5] = (byte)value; + + return SendAsync(frame.Raw); + } + + private async Task SendHeadersAsync(int streamId, Http2HeadersFrameFlags flags, IEnumerable> headers) + { + var frame = new Http2Frame(); + + frame.PrepareHeaders(flags, streamId); + var done = _hpackEncoder.BeginEncode(headers, frame.Payload, out var length); + frame.Length = length; + + await SendAsync(frame.Raw); + + return done; + } + + private Task SendInvalidHeadersFrameAsync(int streamId, int frameLength, byte padLength) + { + Assert.True(padLength >= frameLength, $"{nameof(padLength)} must be greater than or equal to {nameof(frameLength)} to create an invalid frame."); + + var frame = new Http2Frame(); + + frame.PrepareHeaders(Http2HeadersFrameFlags.PADDED, streamId); + frame.Payload[0] = padLength; + + // Set length last so .Payload can be written to + frame.Length = frameLength; + + return SendAsync(frame.Raw); + } + + private async Task SendContinuationAsync(int streamId, Http2ContinuationFrameFlags flags) + { + var frame = new Http2Frame(); + + frame.PrepareContinuation(flags, streamId); + var done =_hpackEncoder.Encode(frame.Payload, out var length); + frame.Length = length; + + await SendAsync(frame.Raw); + + return done; + } + + private Task SendDataAsync(int streamId, Span data, bool endStream) + { + var frame = new Http2Frame(); + + frame.PrepareData(streamId); + frame.Length = data.Length; + frame.DataFlags = endStream ? Http2DataFrameFlags.END_STREAM : Http2DataFrameFlags.NONE; + data.CopyTo(frame.DataPayload); + + return SendAsync(frame.Raw); + } + + private Task SendDataWithPaddingAsync(int streamId, Span data, byte padLength, bool endStream) + { + var frame = new Http2Frame(); + + frame.PrepareData(streamId, padLength); + frame.Length = data.Length + 1 + padLength; + data.CopyTo(frame.DataPayload); + + if (endStream) + { + frame.DataFlags |= Http2DataFrameFlags.END_STREAM; + } + + return SendAsync(frame.Raw); + } + + private Task SendInvalidDataFrameAsync(int streamId, int frameLength, byte padLength) + { + Assert.True(padLength >= frameLength, $"{nameof(padLength)} must be greater than or equal to {nameof(frameLength)} to create an invalid frame."); + + var frame = new Http2Frame(); + + frame.PrepareData(streamId); + frame.DataFlags = Http2DataFrameFlags.PADDED; + frame.Payload[0] = padLength; + + // Set length last so .Payload can be written to + frame.Length = frameLength; + + return SendAsync(frame.Raw); + } + + private Task SendPingAsync() + { + var pingFrame = new Http2Frame(); + pingFrame.PreparePing(Http2PingFrameFlags.NONE); + return SendAsync(pingFrame.Raw); + } + + private Task SendPingWithInvalidLengthAsync(int length) + { + var pingFrame = new Http2Frame(); + pingFrame.PreparePing(Http2PingFrameFlags.NONE); + pingFrame.Length = length; + return SendAsync(pingFrame.Raw); + } + + private Task SendPriorityAsync(int streamId) + { + var priorityFrame = new Http2Frame(); + priorityFrame.PreparePriority(streamId, streamDependency: 0, exclusive: false, weight: 0); + return SendAsync(priorityFrame.Raw); + } + + private Task SendInvalidPriorityFrameAsync(int streamId, int length) + { + var priorityFrame = new Http2Frame(); + priorityFrame.PreparePriority(streamId, streamDependency: 0, exclusive: false, weight: 0); + priorityFrame.Length = length; + return SendAsync(priorityFrame.Raw); + } + + private Task SendRstStreamAsync(int streamId) + { + var rstStreamFrame = new Http2Frame(); + rstStreamFrame.PrepareRstStream(streamId, Http2ErrorCode.CANCEL); + return SendAsync(rstStreamFrame.Raw); + } + + private Task SendInvalidRstStreamFrameAsync(int streamId, int length) + { + var frame = new Http2Frame(); + frame.PrepareRstStream(streamId, Http2ErrorCode.CANCEL); + frame.Length = length; + return SendAsync(frame.Raw); + } + + private Task SendGoAwayAsync() + { + var frame = new Http2Frame(); + frame.PrepareGoAway(0, Http2ErrorCode.NO_ERROR); + return SendAsync(frame.Raw); + } + + private Task SendWindowUpdateAsync(int streamId, int sizeIncrement) + { + var frame = new Http2Frame(); + frame.PrepareWindowUpdate(streamId, sizeIncrement); + return SendAsync(frame.Raw); + } + + private Task SendInvalidWindowUpdateAsync(int streamId, int sizeIncrement, int length) + { + var frame = new Http2Frame(); + frame.PrepareWindowUpdate(streamId, sizeIncrement); + frame.Length = length; + return SendAsync(frame.Raw); + } + + private async Task ReceiveFrameAsync() + { + var frame = new Http2Frame(); + + while (true) + { + var result = await _outputPipe.Reader.ReadAsync(); + var buffer = result.Buffer; + var consumed = buffer.Start; + var examined = buffer.End; + + try + { + Assert.True(buffer.Length > 0); + + if (Http2FrameReader.ReadFrame(buffer, frame, out consumed, out examined)) + { + return frame; + } + } + finally + { + _outputPipe.Reader.Advance(consumed, examined); + } + } + } + + private async Task ReceiveSettingsAck() + { + var frame = await ReceiveFrameAsync(); + + Assert.Equal(Http2FrameType.SETTINGS, frame.Type); + Assert.Equal(Http2SettingsFrameFlags.ACK, frame.SettingsFlags); + } + + private async Task ExpectAsync(Http2FrameType type, int withLength, byte withFlags, int withStreamId) + { + var frame = await ReceiveFrameAsync(); + + Assert.Equal(type, frame.Type); + Assert.Equal(withLength, frame.Length); + Assert.Equal(withFlags, frame.Flags); + Assert.Equal(withStreamId, frame.StreamId); + + return frame; + } + + private Task StopConnectionAsync(int expectedLastStreamId, bool ignoreNonGoAwayFrames) + { + _inputPipe.Writer.Complete(); + + return WaitForConnectionStopAsync(expectedLastStreamId, ignoreNonGoAwayFrames); + } + + private Task WaitForConnectionStopAsync(int expectedLastStreamId, bool ignoreNonGoAwayFrames) + { + return WaitForConnectionErrorAsync(expectedLastStreamId, Http2ErrorCode.NO_ERROR, ignoreNonGoAwayFrames); + } + + private async Task WaitForConnectionErrorAsync(int expectedLastStreamId, Http2ErrorCode expectedErrorCode, bool ignoreNonGoAwayFrames) + { + var frame = await ReceiveFrameAsync(); + + if (ignoreNonGoAwayFrames) + { + while (frame.Type != Http2FrameType.GOAWAY) + { + frame = await ReceiveFrameAsync(); + } + } + + Assert.Equal(Http2FrameType.GOAWAY, frame.Type); + Assert.Equal(8, frame.Length); + Assert.Equal(0, frame.Flags); + Assert.Equal(0, frame.StreamId); + Assert.Equal(expectedLastStreamId, frame.GoAwayLastStreamId); + Assert.Equal(expectedErrorCode, frame.GoAwayErrorCode); + + await _connectionTask; + _inputPipe.Writer.Complete(); + } + + private async Task WaitForStreamErrorAsync(int expectedStreamId, Http2ErrorCode expectedErrorCode, bool ignoreNonRstStreamFrames) + { + var frame = await ReceiveFrameAsync(); + + if (ignoreNonRstStreamFrames) + { + while (frame.Type != Http2FrameType.RST_STREAM) + { + frame = await ReceiveFrameAsync(); + } + } + + Assert.Equal(Http2FrameType.RST_STREAM, frame.Type); + Assert.Equal(4, frame.Length); + Assert.Equal(0, frame.Flags); + Assert.Equal(expectedStreamId, frame.StreamId); + Assert.Equal(expectedErrorCode, frame.RstStreamErrorCode); + } + + private void VerifyDecodedRequestHeaders(IEnumerable> expectedHeaders) + { + foreach (var header in expectedHeaders) + { + Assert.True(_receivedHeaders.TryGetValue(header.Key, out var value), header.Key); + Assert.Equal(header.Value, value, ignoreCase: true); + } + } + } +} diff --git a/test/Kestrel.Core.Tests/HuffmanTests.cs b/test/Kestrel.Core.Tests/HuffmanTests.cs new file mode 100644 index 0000000000..7e0375c5da --- /dev/null +++ b/test/Kestrel.Core.Tests/HuffmanTests.cs @@ -0,0 +1,329 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class HuffmanTests + { + [Fact] + public void HuffmanDecodeString() + { + // h e.......e l........l l o.......o + var encodedHello = new byte[] { 0b100111_00, 0b101_10100, 0b0_101000_0, 0b0111_1111 }; + + Assert.Equal("hello", Huffman.Decode(encodedHello, 0, encodedHello.Length)); + + var encodedHeader = new byte[] + { + 0xb6, 0xb9, 0xac, 0x1c, 0x85, 0x58, 0xd5, 0x20, 0xa4, 0xb6, 0xc2, 0xad, 0x61, 0x7b, 0x5a, 0x54, 0x25, 0x1f + }; + + Assert.Equal("upgrade-insecure-requests", Huffman.Decode(encodedHeader, 0, encodedHeader.Length)); + + encodedHeader = new byte[] + { + // "t + 0xfe, 0x53 + }; + + Assert.Equal("\"t", Huffman.Decode(encodedHeader, 0, encodedHeader.Length)); + } + + [Theory] + [MemberData(nameof(HuffmanData))] + public void HuffmanEncode(int code, uint expectedEncoded, int expectedBitLength) + { + var (encoded, bitLength) = Huffman.Encode(code); + Assert.Equal(expectedEncoded, encoded); + Assert.Equal(expectedBitLength, bitLength); + } + + [Theory] + [MemberData(nameof(HuffmanData))] + public void HuffmanDecode(int code, uint encoded, int bitLength) + { + Assert.Equal(code, Huffman.Decode(encoded, out var decodedBits)); + Assert.Equal(bitLength, decodedBits); + } + + [Theory] + [MemberData(nameof(HuffmanData))] + public void HuffmanEncodeDecode(int code, uint encoded, int bitLength) + { + Assert.Equal(code, Huffman.Decode(Huffman.Encode(code).encoded, out var decodedBits)); + Assert.Equal(bitLength, decodedBits); + } + + public static TheoryData HuffmanData + { + get + { + var data = new TheoryData(); + + data.Add(0, 0b11111111_11000000_00000000_00000000, 13); + data.Add(1, 0b11111111_11111111_10110000_00000000, 23); + data.Add(2, 0b11111111_11111111_11111110_00100000, 28); + data.Add(3, 0b11111111_11111111_11111110_00110000, 28); + data.Add(4, 0b11111111_11111111_11111110_01000000, 28); + data.Add(5, 0b11111111_11111111_11111110_01010000, 28); + data.Add(6, 0b11111111_11111111_11111110_01100000, 28); + data.Add(7, 0b11111111_11111111_11111110_01110000, 28); + data.Add(8, 0b11111111_11111111_11111110_10000000, 28); + data.Add(9, 0b11111111_11111111_11101010_00000000, 24); + data.Add(10, 0b11111111_11111111_11111111_11110000, 30); + data.Add(11, 0b11111111_11111111_11111110_10010000, 28); + data.Add(12, 0b11111111_11111111_11111110_10100000, 28); + data.Add(13, 0b11111111_11111111_11111111_11110100, 30); + data.Add(14, 0b11111111_11111111_11111110_10110000, 28); + data.Add(15, 0b11111111_11111111_11111110_11000000, 28); + data.Add(16, 0b11111111_11111111_11111110_11010000, 28); + data.Add(17, 0b11111111_11111111_11111110_11100000, 28); + data.Add(18, 0b11111111_11111111_11111110_11110000, 28); + data.Add(19, 0b11111111_11111111_11111111_00000000, 28); + data.Add(20, 0b11111111_11111111_11111111_00010000, 28); + data.Add(21, 0b11111111_11111111_11111111_00100000, 28); + data.Add(22, 0b11111111_11111111_11111111_11111000, 30); + data.Add(23, 0b11111111_11111111_11111111_00110000, 28); + data.Add(24, 0b11111111_11111111_11111111_01000000, 28); + data.Add(25, 0b11111111_11111111_11111111_01010000, 28); + data.Add(26, 0b11111111_11111111_11111111_01100000, 28); + data.Add(27, 0b11111111_11111111_11111111_01110000, 28); + data.Add(28, 0b11111111_11111111_11111111_10000000, 28); + data.Add(29, 0b11111111_11111111_11111111_10010000, 28); + data.Add(30, 0b11111111_11111111_11111111_10100000, 28); + data.Add(31, 0b11111111_11111111_11111111_10110000, 28); + data.Add(32, 0b01010000_00000000_00000000_00000000, 6); + data.Add(33, 0b11111110_00000000_00000000_00000000, 10); + data.Add(34, 0b11111110_01000000_00000000_00000000, 10); + data.Add(35, 0b11111111_10100000_00000000_00000000, 12); + data.Add(36, 0b11111111_11001000_00000000_00000000, 13); + data.Add(37, 0b01010100_00000000_00000000_00000000, 6); + data.Add(38, 0b11111000_00000000_00000000_00000000, 8); + data.Add(39, 0b11111111_01000000_00000000_00000000, 11); + data.Add(40, 0b11111110_10000000_00000000_00000000, 10); + data.Add(41, 0b11111110_11000000_00000000_00000000, 10); + data.Add(42, 0b11111001_00000000_00000000_00000000, 8); + data.Add(43, 0b11111111_01100000_00000000_00000000, 11); + data.Add(44, 0b11111010_00000000_00000000_00000000, 8); + data.Add(45, 0b01011000_00000000_00000000_00000000, 6); + data.Add(46, 0b01011100_00000000_00000000_00000000, 6); + data.Add(47, 0b01100000_00000000_00000000_00000000, 6); + data.Add(48, 0b00000000_00000000_00000000_00000000, 5); + data.Add(49, 0b00001000_00000000_00000000_00000000, 5); + data.Add(50, 0b00010000_00000000_00000000_00000000, 5); + data.Add(51, 0b01100100_00000000_00000000_00000000, 6); + data.Add(52, 0b01101000_00000000_00000000_00000000, 6); + data.Add(53, 0b01101100_00000000_00000000_00000000, 6); + data.Add(54, 0b01110000_00000000_00000000_00000000, 6); + data.Add(55, 0b01110100_00000000_00000000_00000000, 6); + data.Add(56, 0b01111000_00000000_00000000_00000000, 6); + data.Add(57, 0b01111100_00000000_00000000_00000000, 6); + data.Add(58, 0b10111000_00000000_00000000_00000000, 7); + data.Add(59, 0b11111011_00000000_00000000_00000000, 8); + data.Add(60, 0b11111111_11111000_00000000_00000000, 15); + data.Add(61, 0b10000000_00000000_00000000_00000000, 6); + data.Add(62, 0b11111111_10110000_00000000_00000000, 12); + data.Add(63, 0b11111111_00000000_00000000_00000000, 10); + data.Add(64, 0b11111111_11010000_00000000_00000000, 13); + data.Add(65, 0b10000100_00000000_00000000_00000000, 6); + data.Add(66, 0b10111010_00000000_00000000_00000000, 7); + data.Add(67, 0b10111100_00000000_00000000_00000000, 7); + data.Add(68, 0b10111110_00000000_00000000_00000000, 7); + data.Add(69, 0b11000000_00000000_00000000_00000000, 7); + data.Add(70, 0b11000010_00000000_00000000_00000000, 7); + data.Add(71, 0b11000100_00000000_00000000_00000000, 7); + data.Add(72, 0b11000110_00000000_00000000_00000000, 7); + data.Add(73, 0b11001000_00000000_00000000_00000000, 7); + data.Add(74, 0b11001010_00000000_00000000_00000000, 7); + data.Add(75, 0b11001100_00000000_00000000_00000000, 7); + data.Add(76, 0b11001110_00000000_00000000_00000000, 7); + data.Add(77, 0b11010000_00000000_00000000_00000000, 7); + data.Add(78, 0b11010010_00000000_00000000_00000000, 7); + data.Add(79, 0b11010100_00000000_00000000_00000000, 7); + data.Add(80, 0b11010110_00000000_00000000_00000000, 7); + data.Add(81, 0b11011000_00000000_00000000_00000000, 7); + data.Add(82, 0b11011010_00000000_00000000_00000000, 7); + data.Add(83, 0b11011100_00000000_00000000_00000000, 7); + data.Add(84, 0b11011110_00000000_00000000_00000000, 7); + data.Add(85, 0b11100000_00000000_00000000_00000000, 7); + data.Add(86, 0b11100010_00000000_00000000_00000000, 7); + data.Add(87, 0b11100100_00000000_00000000_00000000, 7); + data.Add(88, 0b11111100_00000000_00000000_00000000, 8); + data.Add(89, 0b11100110_00000000_00000000_00000000, 7); + data.Add(90, 0b11111101_00000000_00000000_00000000, 8); + data.Add(91, 0b11111111_11011000_00000000_00000000, 13); + data.Add(92, 0b11111111_11111110_00000000_00000000, 19); + data.Add(93, 0b11111111_11100000_00000000_00000000, 13); + data.Add(94, 0b11111111_11110000_00000000_00000000, 14); + data.Add(95, 0b10001000_00000000_00000000_00000000, 6); + data.Add(96, 0b11111111_11111010_00000000_00000000, 15); + data.Add(97, 0b00011000_00000000_00000000_00000000, 5); + data.Add(98, 0b10001100_00000000_00000000_00000000, 6); + data.Add(99, 0b00100000_00000000_00000000_00000000, 5); + data.Add(100, 0b10010000_00000000_00000000_00000000, 6); + data.Add(101, 0b00101000_00000000_00000000_00000000, 5); + data.Add(102, 0b10010100_00000000_00000000_00000000, 6); + data.Add(103, 0b10011000_00000000_00000000_00000000, 6); + data.Add(104, 0b10011100_00000000_00000000_00000000, 6); + data.Add(105, 0b00110000_00000000_00000000_00000000, 5); + data.Add(106, 0b11101000_00000000_00000000_00000000, 7); + data.Add(107, 0b11101010_00000000_00000000_00000000, 7); + data.Add(108, 0b10100000_00000000_00000000_00000000, 6); + data.Add(109, 0b10100100_00000000_00000000_00000000, 6); + data.Add(110, 0b10101000_00000000_00000000_00000000, 6); + data.Add(111, 0b00111000_00000000_00000000_00000000, 5); + data.Add(112, 0b10101100_00000000_00000000_00000000, 6); + data.Add(113, 0b11101100_00000000_00000000_00000000, 7); + data.Add(114, 0b10110000_00000000_00000000_00000000, 6); + data.Add(115, 0b01000000_00000000_00000000_00000000, 5); + data.Add(116, 0b01001000_00000000_00000000_00000000, 5); + data.Add(117, 0b10110100_00000000_00000000_00000000, 6); + data.Add(118, 0b11101110_00000000_00000000_00000000, 7); + data.Add(119, 0b11110000_00000000_00000000_00000000, 7); + data.Add(120, 0b11110010_00000000_00000000_00000000, 7); + data.Add(121, 0b11110100_00000000_00000000_00000000, 7); + data.Add(122, 0b11110110_00000000_00000000_00000000, 7); + data.Add(123, 0b11111111_11111100_00000000_00000000, 15); + data.Add(124, 0b11111111_10000000_00000000_00000000, 11); + data.Add(125, 0b11111111_11110100_00000000_00000000, 14); + data.Add(126, 0b11111111_11101000_00000000_00000000, 13); + data.Add(127, 0b11111111_11111111_11111111_11000000, 28); + data.Add(128, 0b11111111_11111110_01100000_00000000, 20); + data.Add(129, 0b11111111_11111111_01001000_00000000, 22); + data.Add(130, 0b11111111_11111110_01110000_00000000, 20); + data.Add(131, 0b11111111_11111110_10000000_00000000, 20); + data.Add(132, 0b11111111_11111111_01001100_00000000, 22); + data.Add(133, 0b11111111_11111111_01010000_00000000, 22); + data.Add(134, 0b11111111_11111111_01010100_00000000, 22); + data.Add(135, 0b11111111_11111111_10110010_00000000, 23); + data.Add(136, 0b11111111_11111111_01011000_00000000, 22); + data.Add(137, 0b11111111_11111111_10110100_00000000, 23); + data.Add(138, 0b11111111_11111111_10110110_00000000, 23); + data.Add(139, 0b11111111_11111111_10111000_00000000, 23); + data.Add(140, 0b11111111_11111111_10111010_00000000, 23); + data.Add(141, 0b11111111_11111111_10111100_00000000, 23); + data.Add(142, 0b11111111_11111111_11101011_00000000, 24); + data.Add(143, 0b11111111_11111111_10111110_00000000, 23); + data.Add(144, 0b11111111_11111111_11101100_00000000, 24); + data.Add(145, 0b11111111_11111111_11101101_00000000, 24); + data.Add(146, 0b11111111_11111111_01011100_00000000, 22); + data.Add(147, 0b11111111_11111111_11000000_00000000, 23); + data.Add(148, 0b11111111_11111111_11101110_00000000, 24); + data.Add(149, 0b11111111_11111111_11000010_00000000, 23); + data.Add(150, 0b11111111_11111111_11000100_00000000, 23); + data.Add(151, 0b11111111_11111111_11000110_00000000, 23); + data.Add(152, 0b11111111_11111111_11001000_00000000, 23); + data.Add(153, 0b11111111_11111110_11100000_00000000, 21); + data.Add(154, 0b11111111_11111111_01100000_00000000, 22); + data.Add(155, 0b11111111_11111111_11001010_00000000, 23); + data.Add(156, 0b11111111_11111111_01100100_00000000, 22); + data.Add(157, 0b11111111_11111111_11001100_00000000, 23); + data.Add(158, 0b11111111_11111111_11001110_00000000, 23); + data.Add(159, 0b11111111_11111111_11101111_00000000, 24); + data.Add(160, 0b11111111_11111111_01101000_00000000, 22); + data.Add(161, 0b11111111_11111110_11101000_00000000, 21); + data.Add(162, 0b11111111_11111110_10010000_00000000, 20); + data.Add(163, 0b11111111_11111111_01101100_00000000, 22); + data.Add(164, 0b11111111_11111111_01110000_00000000, 22); + data.Add(165, 0b11111111_11111111_11010000_00000000, 23); + data.Add(166, 0b11111111_11111111_11010010_00000000, 23); + data.Add(167, 0b11111111_11111110_11110000_00000000, 21); + data.Add(168, 0b11111111_11111111_11010100_00000000, 23); + data.Add(169, 0b11111111_11111111_01110100_00000000, 22); + data.Add(170, 0b11111111_11111111_01111000_00000000, 22); + data.Add(171, 0b11111111_11111111_11110000_00000000, 24); + data.Add(172, 0b11111111_11111110_11111000_00000000, 21); + data.Add(173, 0b11111111_11111111_01111100_00000000, 22); + data.Add(174, 0b11111111_11111111_11010110_00000000, 23); + data.Add(175, 0b11111111_11111111_11011000_00000000, 23); + data.Add(176, 0b11111111_11111111_00000000_00000000, 21); + data.Add(177, 0b11111111_11111111_00001000_00000000, 21); + data.Add(178, 0b11111111_11111111_10000000_00000000, 22); + data.Add(179, 0b11111111_11111111_00010000_00000000, 21); + data.Add(180, 0b11111111_11111111_11011010_00000000, 23); + data.Add(181, 0b11111111_11111111_10000100_00000000, 22); + data.Add(182, 0b11111111_11111111_11011100_00000000, 23); + data.Add(183, 0b11111111_11111111_11011110_00000000, 23); + data.Add(184, 0b11111111_11111110_10100000_00000000, 20); + data.Add(185, 0b11111111_11111111_10001000_00000000, 22); + data.Add(186, 0b11111111_11111111_10001100_00000000, 22); + data.Add(187, 0b11111111_11111111_10010000_00000000, 22); + data.Add(188, 0b11111111_11111111_11100000_00000000, 23); + data.Add(189, 0b11111111_11111111_10010100_00000000, 22); + data.Add(190, 0b11111111_11111111_10011000_00000000, 22); + data.Add(191, 0b11111111_11111111_11100010_00000000, 23); + data.Add(192, 0b11111111_11111111_11111000_00000000, 26); + data.Add(193, 0b11111111_11111111_11111000_01000000, 26); + data.Add(194, 0b11111111_11111110_10110000_00000000, 20); + data.Add(195, 0b11111111_11111110_00100000_00000000, 19); + data.Add(196, 0b11111111_11111111_10011100_00000000, 22); + data.Add(197, 0b11111111_11111111_11100100_00000000, 23); + data.Add(198, 0b11111111_11111111_10100000_00000000, 22); + data.Add(199, 0b11111111_11111111_11110110_00000000, 25); + data.Add(200, 0b11111111_11111111_11111000_10000000, 26); + data.Add(201, 0b11111111_11111111_11111000_11000000, 26); + data.Add(202, 0b11111111_11111111_11111001_00000000, 26); + data.Add(203, 0b11111111_11111111_11111011_11000000, 27); + data.Add(204, 0b11111111_11111111_11111011_11100000, 27); + data.Add(205, 0b11111111_11111111_11111001_01000000, 26); + data.Add(206, 0b11111111_11111111_11110001_00000000, 24); + data.Add(207, 0b11111111_11111111_11110110_10000000, 25); + data.Add(208, 0b11111111_11111110_01000000_00000000, 19); + data.Add(209, 0b11111111_11111111_00011000_00000000, 21); + data.Add(210, 0b11111111_11111111_11111001_10000000, 26); + data.Add(211, 0b11111111_11111111_11111100_00000000, 27); + data.Add(212, 0b11111111_11111111_11111100_00100000, 27); + data.Add(213, 0b11111111_11111111_11111001_11000000, 26); + data.Add(214, 0b11111111_11111111_11111100_01000000, 27); + data.Add(215, 0b11111111_11111111_11110010_00000000, 24); + data.Add(216, 0b11111111_11111111_00100000_00000000, 21); + data.Add(217, 0b11111111_11111111_00101000_00000000, 21); + data.Add(218, 0b11111111_11111111_11111010_00000000, 26); + data.Add(219, 0b11111111_11111111_11111010_01000000, 26); + data.Add(220, 0b11111111_11111111_11111111_11010000, 28); + data.Add(221, 0b11111111_11111111_11111100_01100000, 27); + data.Add(222, 0b11111111_11111111_11111100_10000000, 27); + data.Add(223, 0b11111111_11111111_11111100_10100000, 27); + data.Add(224, 0b11111111_11111110_11000000_00000000, 20); + data.Add(225, 0b11111111_11111111_11110011_00000000, 24); + data.Add(226, 0b11111111_11111110_11010000_00000000, 20); + data.Add(227, 0b11111111_11111111_00110000_00000000, 21); + data.Add(228, 0b11111111_11111111_10100100_00000000, 22); + data.Add(229, 0b11111111_11111111_00111000_00000000, 21); + data.Add(230, 0b11111111_11111111_01000000_00000000, 21); + data.Add(231, 0b11111111_11111111_11100110_00000000, 23); + data.Add(232, 0b11111111_11111111_10101000_00000000, 22); + data.Add(233, 0b11111111_11111111_10101100_00000000, 22); + data.Add(234, 0b11111111_11111111_11110111_00000000, 25); + data.Add(235, 0b11111111_11111111_11110111_10000000, 25); + data.Add(236, 0b11111111_11111111_11110100_00000000, 24); + data.Add(237, 0b11111111_11111111_11110101_00000000, 24); + data.Add(238, 0b11111111_11111111_11111010_10000000, 26); + data.Add(239, 0b11111111_11111111_11101000_00000000, 23); + data.Add(240, 0b11111111_11111111_11111010_11000000, 26); + data.Add(241, 0b11111111_11111111_11111100_11000000, 27); + data.Add(242, 0b11111111_11111111_11111011_00000000, 26); + data.Add(243, 0b11111111_11111111_11111011_01000000, 26); + data.Add(244, 0b11111111_11111111_11111100_11100000, 27); + data.Add(245, 0b11111111_11111111_11111101_00000000, 27); + data.Add(246, 0b11111111_11111111_11111101_00100000, 27); + data.Add(247, 0b11111111_11111111_11111101_01000000, 27); + data.Add(248, 0b11111111_11111111_11111101_01100000, 27); + data.Add(249, 0b11111111_11111111_11111111_11100000, 28); + data.Add(250, 0b11111111_11111111_11111101_10000000, 27); + data.Add(251, 0b11111111_11111111_11111101_10100000, 27); + data.Add(252, 0b11111111_11111111_11111101_11000000, 27); + data.Add(253, 0b11111111_11111111_11111101_11100000, 27); + data.Add(254, 0b11111111_11111111_11111110_00000000, 27); + data.Add(255, 0b11111111_11111111_11111011_10000000, 26); + data.Add(256, 0b11111111_11111111_11111111_11111100, 30); + + return data; + } + } + } +} diff --git a/test/Kestrel.Core.Tests/IntegerDecoderTests.cs b/test/Kestrel.Core.Tests/IntegerDecoderTests.cs new file mode 100644 index 0000000000..c9e259b0ff --- /dev/null +++ b/test/Kestrel.Core.Tests/IntegerDecoderTests.cs @@ -0,0 +1,52 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class IntegerDecoderTests + { + [Theory] + [MemberData(nameof(IntegerData))] + public void IntegerDecode(int i, int prefixLength, byte[] octets) + { + var decoder = new IntegerDecoder(); + var result = decoder.BeginDecode(octets[0], prefixLength); + + if (octets.Length == 1) + { + Assert.True(result); + } + else + { + var j = 1; + + for (; j < octets.Length - 1; j++) + { + Assert.False(decoder.Decode(octets[j])); + } + + Assert.True(decoder.Decode(octets[j])); + } + + Assert.Equal(i, decoder.Value); + } + + public static TheoryData IntegerData + { + get + { + var data = new TheoryData(); + + data.Add(10, 5, new byte[] { 10 }); + data.Add(1337, 5, new byte[] { 0x1f, 0x9a, 0x0a }); + data.Add(42, 8, new byte[] { 42 }); + + return data; + } + } + } +} diff --git a/test/Kestrel.Core.Tests/IntegerEncoderTests.cs b/test/Kestrel.Core.Tests/IntegerEncoderTests.cs new file mode 100644 index 0000000000..2c4811c81c --- /dev/null +++ b/test/Kestrel.Core.Tests/IntegerEncoderTests.cs @@ -0,0 +1,37 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2.HPack; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class IntegerEncoderTests + { + [Theory] + [MemberData(nameof(IntegerData))] + public void IntegerEncode(int i, int prefixLength, byte[] expectedOctets) + { + var buffer = new byte[expectedOctets.Length]; + + Assert.True(IntegerEncoder.Encode(i, prefixLength, buffer, out var octets)); + Assert.Equal(expectedOctets.Length, octets); + Assert.Equal(expectedOctets, buffer); + } + + public static TheoryData IntegerData + { + get + { + var data = new TheoryData(); + + data.Add(10, 5, new byte[] { 10 }); + data.Add(1337, 5, new byte[] { 0x1f, 0x9a, 0x0a }); + data.Add(42, 8, new byte[] { 42 }); + + return data; + } + } + } +} diff --git a/test/Kestrel.FunctionalTests/GeneratedCodeTests.cs b/test/Kestrel.FunctionalTests/GeneratedCodeTests.cs index 5a6b03781c..548a028a8a 100644 --- a/test/Kestrel.FunctionalTests/GeneratedCodeTests.cs +++ b/test/Kestrel.FunctionalTests/GeneratedCodeTests.cs @@ -14,26 +14,31 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { const string frameHeadersGeneratedPath = "../../../../../src/Kestrel.Core/Internal/Http/FrameHeaders.Generated.cs"; const string frameGeneratedPath = "../../../../../src/Kestrel.Core/Internal/Http/Frame.Generated.cs"; + const string http2StreamGeneratedPath = "../../../../../src/Kestrel.Core/Internal/Http2/Http2Stream.Generated.cs"; const string httpUtilitiesGeneratedPath = "../../../../../src/Kestrel.Core/Internal/Infrastructure/HttpUtilities.Generated.cs"; var testFrameHeadersGeneratedPath = Path.GetTempFileName(); var testFrameGeneratedPath = Path.GetTempFileName(); + var testHttp2StreamGeneratedPath = Path.GetTempFileName(); var testHttpUtilitiesGeneratedPath = Path.GetTempFileName(); try { var currentFrameHeadersGenerated = File.ReadAllText(frameHeadersGeneratedPath); var currentFrameGenerated = File.ReadAllText(frameGeneratedPath); + var currentHttp2StreamGenerated = File.ReadAllText(http2StreamGeneratedPath); var currentHttpUtilitiesGenerated = File.ReadAllText(httpUtilitiesGeneratedPath); - CodeGenerator.Program.Run(testFrameHeadersGeneratedPath, testFrameGeneratedPath, testHttpUtilitiesGeneratedPath); + CodeGenerator.Program.Run(testFrameHeadersGeneratedPath, testFrameGeneratedPath, testHttp2StreamGeneratedPath, testHttpUtilitiesGeneratedPath); var testFrameHeadersGenerated = File.ReadAllText(testFrameHeadersGeneratedPath); var testFrameGenerated = File.ReadAllText(testFrameGeneratedPath); + var testHttp2StreamGenerated = File.ReadAllText(testHttp2StreamGeneratedPath); var testHttpUtilitiesGenerated = File.ReadAllText(testHttpUtilitiesGeneratedPath); Assert.Equal(currentFrameHeadersGenerated, testFrameHeadersGenerated, ignoreLineEndingDifferences: true); Assert.Equal(currentFrameGenerated, testFrameGenerated, ignoreLineEndingDifferences: true); + Assert.Equal(currentHttp2StreamGenerated, testHttp2StreamGenerated, ignoreLineEndingDifferences: true); Assert.Equal(currentHttpUtilitiesGenerated, testHttpUtilitiesGenerated, ignoreLineEndingDifferences: true); } @@ -41,6 +46,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { File.Delete(testFrameHeadersGeneratedPath); File.Delete(testFrameGeneratedPath); + File.Delete(testHttp2StreamGeneratedPath); File.Delete(testHttpUtilitiesGeneratedPath); } } diff --git a/tools/CodeGenerator/CodeGenerator.csproj b/tools/CodeGenerator/CodeGenerator.csproj index 143a966654..91cec270ff 100644 --- a/tools/CodeGenerator/CodeGenerator.csproj +++ b/tools/CodeGenerator/CodeGenerator.csproj @@ -20,7 +20,7 @@ $(MSBuildThisFileDirectory)..\..\src\Kestrel.Core - Internal\Http\FrameHeaders.Generated.cs Internal\Http\Frame.Generated.cs Internal\Infrastructure\HttpUtilities.Generated.cs + Internal/Http/FrameHeaders.Generated.cs Internal/Http/Frame.Generated.cs Internal/Http2/Http2Stream.Generated.cs Internal/Infrastructure/HttpUtilities.Generated.cs diff --git a/tools/CodeGenerator/FrameFeatureCollection.cs b/tools/CodeGenerator/FrameFeatureCollection.cs index d05139d919..b32dacbb72 100644 --- a/tools/CodeGenerator/FrameFeatureCollection.cs +++ b/tools/CodeGenerator/FrameFeatureCollection.cs @@ -19,8 +19,10 @@ namespace CodeGenerator return values.Select(formatter).Aggregate((a, b) => a + b); } - public static string GeneratedFile() + public static string GeneratedFile(string className, string namespaceSuffix, IEnumerable additionalFeatures = null) { + additionalFeatures = additionalFeatures ?? new Type[] { }; + var alwaysFeatures = new[] { typeof(IHttpRequestFeature), @@ -57,7 +59,7 @@ namespace CodeGenerator typeof(IHttpSendFileFeature), }; - var allFeatures = alwaysFeatures.Concat(commonFeatures).Concat(sometimesFeatures).Concat(rareFeatures); + var allFeatures = alwaysFeatures.Concat(commonFeatures).Concat(sometimesFeatures).Concat(rareFeatures).Concat(additionalFeatures); // NOTE: This list MUST always match the set of feature interfaces implemented by Frame. // See also: src/Kestrel/Http/Frame.FeatureCollection.cs @@ -73,7 +75,7 @@ namespace CodeGenerator typeof(IHttpMinRequestBodyDataRateFeature), typeof(IHttpMinResponseDataRateFeature), typeof(IHttpBodyControlFeature), - }; + }.Concat(additionalFeatures); return $@"// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. @@ -81,9 +83,9 @@ namespace CodeGenerator using System; using System.Collections.Generic; -namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.{namespaceSuffix} {{ - public partial class Frame + public partial class {className} {{{Each(allFeatures, feature => $@" private static readonly Type {feature.Name}Type = typeof(global::{feature.FullName});")} {Each(allFeatures, feature => $@" diff --git a/tools/CodeGenerator/Program.cs b/tools/CodeGenerator/Program.cs index 24be8360b9..59c190a2be 100644 --- a/tools/CodeGenerator/Program.cs +++ b/tools/CodeGenerator/Program.cs @@ -3,6 +3,9 @@ using System; using System.IO; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; namespace CodeGenerator { @@ -26,15 +29,16 @@ namespace CodeGenerator return 1; } - Run(args[0], args[1], args[2]); + Run(args[0], args[1], args[2], args[3]); return 0; } - public static void Run(string knownHeadersPath, string frameFeaturesCollectionPath, string httpUtilitiesPath) + public static void Run(string knownHeadersPath, string frameFeatureCollectionPath, string http2StreamFeatureCollectionPath, string httpUtilitiesPath) { var knownHeadersContent = KnownHeaders.GeneratedFile(); - var frameFeatureCollectionContent = FrameFeatureCollection.GeneratedFile(); + var frameFeatureCollectionContent = FrameFeatureCollection.GeneratedFile(nameof(Frame), "Http"); + var http2StreamFeatureCollectionContent = FrameFeatureCollection.GeneratedFile(nameof(Http2Stream), "Http2", new[] { typeof(IHttp2StreamIdFeature) }); var httpUtilitiesContent = HttpUtilities.HttpUtilities.GeneratedFile(); var existingKnownHeaders = File.Exists(knownHeadersPath) ? File.ReadAllText(knownHeadersPath) : ""; @@ -43,10 +47,16 @@ namespace CodeGenerator File.WriteAllText(knownHeadersPath, knownHeadersContent); } - var existingFrameFeatureCollection = File.Exists(frameFeaturesCollectionPath) ? File.ReadAllText(frameFeaturesCollectionPath) : ""; + var existingFrameFeatureCollection = File.Exists(frameFeatureCollectionPath) ? File.ReadAllText(frameFeatureCollectionPath) : ""; if (!string.Equals(frameFeatureCollectionContent, existingFrameFeatureCollection)) { - File.WriteAllText(frameFeaturesCollectionPath, frameFeatureCollectionContent); + File.WriteAllText(frameFeatureCollectionPath, frameFeatureCollectionContent); + } + + var existingHttp2StreamFeatureCollection = File.Exists(http2StreamFeatureCollectionPath) ? File.ReadAllText(http2StreamFeatureCollectionPath) : ""; + if (!string.Equals(http2StreamFeatureCollectionContent, existingHttp2StreamFeatureCollection)) + { + File.WriteAllText(http2StreamFeatureCollectionPath, http2StreamFeatureCollectionContent); } var existingHttpUtilities = File.Exists(httpUtilitiesPath) ? File.ReadAllText(httpUtilitiesPath) : "";