From 3fbfba63f813de004030615546546360b565b8d5 Mon Sep 17 00:00:00 2001 From: Cesar Blum Silveira Date: Wed, 18 Oct 2017 16:31:50 -0700 Subject: [PATCH] HTTP/2: implement 100-continue (#2106) --- .../Internal/Http/Http1MessageBody.cs | 10 ----- src/Kestrel.Core/Internal/Http/MessageBody.cs | 12 +++++- .../Internal/Http2/Http2FrameWriter.cs | 12 +++++- .../Internal/Http2/Http2MessageBody.cs | 9 ++++ .../Internal/Http2/Http2Stream.cs | 2 + .../Http2ConnectionTests.cs | 41 +++++++++++++++++++ 6 files changed, 74 insertions(+), 12 deletions(-) diff --git a/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs b/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs index 4f623108d8..9b50aba2e3 100644 --- a/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs +++ b/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs @@ -14,7 +14,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { private readonly Http1Connection _context; - private bool _send100Continue = true; private volatile bool _canceled; private Task _pumpTask; @@ -150,15 +149,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } } - private void TryProduceContinue() - { - if (_send100Continue) - { - _context.HttpResponseControl.ProduceContinue(); - _send100Continue = false; - } - } - protected void Copy(ReadableBuffer readableBuffer, WritableBuffer writableBuffer) { _context.TimeoutControl.BytesRead(readableBuffer.Length); diff --git a/src/Kestrel.Core/Internal/Http/MessageBody.cs b/src/Kestrel.Core/Internal/Http/MessageBody.cs index 8ae912435a..2fd8c12e60 100644 --- a/src/Kestrel.Core/Internal/Http/MessageBody.cs +++ b/src/Kestrel.Core/Internal/Http/MessageBody.cs @@ -6,7 +6,6 @@ using System.IO; using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http @@ -18,6 +17,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private readonly HttpProtocol _context; + private bool _send100Continue = true; + protected MessageBody(HttpProtocol context) { _context = context; @@ -111,6 +112,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http public abstract Task StopAsync(); + protected void TryProduceContinue() + { + if (_send100Continue) + { + _context.HttpResponseControl.ProduceContinue(); + _send100Continue = false; + } + } + private void TryInit() { if (!_context.HasStartedConsumingRequestBody) diff --git a/src/Kestrel.Core/Internal/Http2/Http2FrameWriter.cs b/src/Kestrel.Core/Internal/Http2/Http2FrameWriter.cs index fb2468d0c4..a34bf0ef53 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2FrameWriter.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2FrameWriter.cs @@ -14,6 +14,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 { public class Http2FrameWriter : IHttp2FrameWriter { + // Literal Header Field without Indexing - Indexed Name (Index 8 - :status) + private static readonly byte[] _continueBytes = new byte[] { 0x08, 0x03, (byte)'1', (byte)'0', (byte)'0' }; + private readonly Http2Frame _outgoingFrame = new Http2Frame(); private readonly object _writeLock = new object(); private readonly HPackEncoder _hpackEncoder = new HPackEncoder(); @@ -50,7 +53,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 public Task Write100ContinueAsync(int streamId) { - return Task.CompletedTask; + lock (_writeLock) + { + _outgoingFrame.PrepareHeaders(Http2HeadersFrameFlags.END_HEADERS, streamId); + _outgoingFrame.Length = _continueBytes.Length; + _continueBytes.CopyTo(_outgoingFrame.HeadersPayload); + + return WriteAsync(_outgoingFrame.Raw); + } } public void WriteResponseHeaders(int streamId, int statusCode, IHeaderDictionary headers) diff --git a/src/Kestrel.Core/Internal/Http2/Http2MessageBody.cs b/src/Kestrel.Core/Internal/Http2/Http2MessageBody.cs index 71a308a8f7..b6ad3af161 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2MessageBody.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2MessageBody.cs @@ -16,6 +16,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 _context = context; } + protected override void OnReadStarted() + { + // Produce 100-continue if no request body data for the stream has arrived yet. + if (!_context.RequestBodyStarted) + { + TryProduceContinue(); + } + } + protected override Task OnConsumeAsync() => Task.CompletedTask; public override Task StopAsync() diff --git a/src/Kestrel.Core/Internal/Http2/Http2Stream.cs b/src/Kestrel.Core/Internal/Http2/Http2Stream.cs index b8f9db04ee..c9832ef03d 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2Stream.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2Stream.cs @@ -24,6 +24,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 public int StreamId => _context.StreamId; + public bool RequestBodyStarted { get; private set; } public bool EndStreamReceived { get; private set; } protected IHttp2StreamLifetimeHandler StreamLifetimeHandler => _context.StreamLifetimeHandler; @@ -84,6 +85,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 writableBuffer.Commit(); } + RequestBodyStarted = true; await writableBuffer.FlushAsync(); } diff --git a/test/Kestrel.Core.Tests/Http2ConnectionTests.cs b/test/Kestrel.Core.Tests/Http2ConnectionTests.cs index 39db80b4a4..e4a875da58 100644 --- a/test/Kestrel.Core.Tests/Http2ConnectionTests.cs +++ b/test/Kestrel.Core.Tests/Http2ConnectionTests.cs @@ -33,6 +33,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests new KeyValuePair(":scheme", "http"), }; + private static readonly IEnumerable> _expectContinueRequestHeaders = new[] + { + new KeyValuePair(":method", "POST"), + new KeyValuePair(":path", "/"), + new KeyValuePair(":authority", "127.0.0.1"), + new KeyValuePair(":scheme", "https"), + new KeyValuePair("expect", "100-continue"), + }; + private static readonly IEnumerable> _browserRequestHeaders = new[] { new KeyValuePair(":method", "GET"), @@ -755,6 +764,38 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false); } + [Fact] + public async Task HEADERS_Received_ContainsExpect100Continue_100ContinueSent() + { + await InitializeConnectionAsync(_echoApplication); + + await StartStreamAsync(1, _expectContinueRequestHeaders, false); + + var frame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 5, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + await SendDataAsync(1, _helloBytes, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 5, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + Assert.Equal(new byte[] { 0x08, 0x03, (byte)'1', (byte)'0', (byte)'0' }, frame.HeadersPayload.ToArray()); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + [Fact] public async Task HEADERS_Received_StreamIdZero_ConnectionError() {