diff --git a/src/Http/Http.Features/ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs b/src/Http/Http.Features/ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs
index 30f567851c..6f53a07297 100644
--- a/src/Http/Http.Features/ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs
+++ b/src/Http/Http.Features/ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs
@@ -200,6 +200,10 @@ namespace Microsoft.AspNetCore.Http.Features
bool Available { get; }
Microsoft.AspNetCore.Http.IHeaderDictionary Trailers { get; }
}
+ public partial interface IHttpResponseCompletionFeature
+ {
+ System.Threading.Tasks.Task CompleteAsync();
+ }
public partial interface IHttpResponseFeature
{
System.IO.Stream Body { get; set; }
diff --git a/src/Http/Http.Features/src/IHttpResponseCompletionFeature.cs b/src/Http/Http.Features/src/IHttpResponseCompletionFeature.cs
new file mode 100644
index 0000000000..eed45e4036
--- /dev/null
+++ b/src/Http/Http.Features/src/IHttpResponseCompletionFeature.cs
@@ -0,0 +1,20 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System.Threading.Tasks;
+
+namespace Microsoft.AspNetCore.Http.Features
+{
+ ///
+ /// A feature to gracefully end a response.
+ ///
+ public interface IHttpResponseCompletionFeature
+ {
+ ///
+ /// Flush any remaining response headers, data, or trailers.
+ /// This may throw if the response is in an invalid state such as a Content-Length mismatch.
+ ///
+ ///
+ Task CompleteAsync();
+ }
+}
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs
index 82b1ebfe6b..f922de8a99 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs
@@ -277,6 +277,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
protected void ResetHttp2Features()
{
_currentIHttp2StreamIdFeature = this;
+ _currentIHttpResponseCompletionFeature = this;
_currentIHttpResponseTrailersFeature = this;
}
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs
index b9b3e26905..f594feed0f 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs
@@ -29,6 +29,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private static readonly Type IFormFeatureType = typeof(IFormFeature);
private static readonly Type IHttpUpgradeFeatureType = typeof(IHttpUpgradeFeature);
private static readonly Type IHttp2StreamIdFeatureType = typeof(IHttp2StreamIdFeature);
+ private static readonly Type IHttpResponseCompletionFeatureType = typeof(IHttpResponseCompletionFeature);
private static readonly Type IHttpResponseTrailersFeatureType = typeof(IHttpResponseTrailersFeature);
private static readonly Type IResponseCookiesFeatureType = typeof(IResponseCookiesFeature);
private static readonly Type IItemsFeatureType = typeof(IItemsFeature);
@@ -58,6 +59,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private object _currentIFormFeature;
private object _currentIHttpUpgradeFeature;
private object _currentIHttp2StreamIdFeature;
+ private object _currentIHttpResponseCompletionFeature;
private object _currentIHttpResponseTrailersFeature;
private object _currentIResponseCookiesFeature;
private object _currentIItemsFeature;
@@ -98,6 +100,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
_currentIQueryFeature = null;
_currentIFormFeature = null;
_currentIHttp2StreamIdFeature = null;
+ _currentIHttpResponseCompletionFeature = null;
_currentIHttpResponseTrailersFeature = null;
_currentIResponseCookiesFeature = null;
_currentIItemsFeature = null;
@@ -224,6 +227,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
feature = _currentIHttp2StreamIdFeature;
}
+ else if (key == IHttpResponseCompletionFeatureType)
+ {
+ feature = _currentIHttpResponseCompletionFeature;
+ }
else if (key == IHttpResponseTrailersFeatureType)
{
feature = _currentIHttpResponseTrailersFeature;
@@ -348,6 +355,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
_currentIHttp2StreamIdFeature = value;
}
+ else if (key == IHttpResponseCompletionFeatureType)
+ {
+ _currentIHttpResponseCompletionFeature = value;
+ }
else if (key == IHttpResponseTrailersFeatureType)
{
_currentIHttpResponseTrailersFeature = value;
@@ -470,6 +481,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
feature = (TFeature)_currentIHttp2StreamIdFeature;
}
+ else if (typeof(TFeature) == typeof(IHttpResponseCompletionFeature))
+ {
+ feature = (TFeature)_currentIHttpResponseCompletionFeature;
+ }
else if (typeof(TFeature) == typeof(IHttpResponseTrailersFeature))
{
feature = (TFeature)_currentIHttpResponseTrailersFeature;
@@ -598,6 +613,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
_currentIHttp2StreamIdFeature = feature;
}
+ else if (typeof(TFeature) == typeof(IHttpResponseCompletionFeature))
+ {
+ _currentIHttpResponseCompletionFeature = feature;
+ }
else if (typeof(TFeature) == typeof(IHttpResponseTrailersFeature))
{
_currentIHttpResponseTrailersFeature = feature;
@@ -718,6 +737,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
yield return new KeyValuePair(IHttp2StreamIdFeatureType, _currentIHttp2StreamIdFeature);
}
+ if (_currentIHttpResponseCompletionFeature != null)
+ {
+ yield return new KeyValuePair(IHttpResponseCompletionFeatureType, _currentIHttpResponseCompletionFeature);
+ }
if (_currentIHttpResponseTrailersFeature != null)
{
yield return new KeyValuePair(IHttpResponseTrailersFeatureType, _currentIHttpResponseTrailersFeature);
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
index 41a4f2e6aa..bc053825c4 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
@@ -210,6 +210,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
public bool RequestTrailersAvailable { get; set; }
public Stream RequestBody { get; set; }
public PipeReader RequestBodyPipeReader { get; set; }
+ public HttpResponseTrailers ResponseTrailers { get; set; }
private int _statusCode;
public int StatusCode
@@ -287,7 +288,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
public bool HasResponseStarted => _requestProcessingStatus >= RequestProcessingStatus.HeadersCommitted;
- public bool HasFlushedHeaders => _requestProcessingStatus == RequestProcessingStatus.HeadersFlushed;
+ public bool HasFlushedHeaders => _requestProcessingStatus >= RequestProcessingStatus.HeadersFlushed;
+
+ public bool HasResponseCompleted => _requestProcessingStatus == RequestProcessingStatus.ResponseCompleted;
protected HttpRequestHeaders HttpRequestHeaders { get; }
@@ -632,9 +635,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
// Run the application code for this request
await application.ProcessRequestAsync(context);
- if (!_connectionAborted)
+ // Trigger OnStarting if it hasn't been called yet and the app hasn't
+ // already failed. If an OnStarting callback throws we can go through
+ // our normal error handling in ProduceEnd.
+ // https://github.com/aspnet/KestrelHttpServer/issues/43
+ if (!HasResponseStarted && _applicationException == null && _onStarting?.Count > 0)
{
- VerifyResponseContentLength();
+ await FireOnStarting();
+ }
+
+ if (!_connectionAborted && !VerifyResponseContentLength(out var lengthException))
+ {
+ ReportApplicationError(lengthException);
}
}
catch (BadHttpRequestException ex)
@@ -652,15 +664,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
KestrelEventSource.Log.RequestStop(this);
- // Trigger OnStarting if it hasn't been called yet and the app hasn't
- // already failed. If an OnStarting callback throws we can go through
- // our normal error handling in ProduceEnd.
- // https://github.com/aspnet/KestrelHttpServer/issues/43
- if (!HasResponseStarted && _applicationException == null && _onStarting?.Count > 0)
- {
- await FireOnStarting();
- }
-
// At this point all user code that needs use to the request or response streams has completed.
// Using these streams in the OnCompleted callback is not allowed.
StopBodies();
@@ -898,7 +901,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
}
}
- protected void VerifyResponseContentLength()
+ protected bool VerifyResponseContentLength(out Exception ex)
{
var responseHeaders = HttpResponseHeaders;
@@ -915,9 +918,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
_keepAlive = false;
}
- ReportApplicationError(new InvalidOperationException(
- CoreStrings.FormatTooFewBytesWritten(_responseBytesWritten, responseHeaders.ContentLength.Value)));
+ ex = new InvalidOperationException(
+ CoreStrings.FormatTooFewBytesWritten(_responseBytesWritten, responseHeaders.ContentLength.Value));
+ return false;
}
+
+ ex = null;
+ return true;
}
public void ProduceContinue()
@@ -1045,6 +1052,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private Task WriteSuffix()
{
+ if (HasResponseCompleted)
+ {
+ return Task.CompletedTask;
+ }
+
// _autoChunk should be checked after we are sure ProduceStart() has been called
// since ProduceStart() may set _autoChunk to true.
if (_autoChunk || _httpVersion == Http.HttpVersion.Http2)
@@ -1064,7 +1076,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
if (!HasFlushedHeaders)
{
- _requestProcessingStatus = RequestProcessingStatus.HeadersFlushed;
+ _requestProcessingStatus = RequestProcessingStatus.ResponseCompleted;
return FlushAsyncInternal();
}
@@ -1080,6 +1092,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
await Output.WriteStreamSuffixAsync();
+ _requestProcessingStatus = RequestProcessingStatus.ResponseCompleted;
+
if (_keepAlive)
{
Log.ConnectionKeepAlive(ConnectionId);
@@ -1244,6 +1258,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
var responseHeaders = HttpResponseHeaders;
responseHeaders.Reset();
+ ResponseTrailers?.Reset();
var dateHeaderValues = DateHeaderValueManager.GetDateHeaderValues();
responseHeaders.SetRawDate(dateHeaderValues.String, dateHeaderValues.Bytes);
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs b/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs
index 61832dc34b..6e27fb5dc8 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs
@@ -10,6 +10,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
ParsingHeaders,
AppStarted,
HeadersCommitted,
- HeadersFlushed
+ HeadersFlushed,
+ ResponseCompleted
}
}
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs
index 4f481850d7..5f0d80a372 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs
@@ -151,7 +151,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
// 2. There is no trailing HEADERS frame.
Http2HeadersFrameFlags http2HeadersFrame;
- if (appCompleted && !_startedWritingDataFrames && (_stream.Trailers == null || _stream.Trailers.Count == 0))
+ if (appCompleted && !_startedWritingDataFrames && (_stream.ResponseTrailers == null || _stream.ResponseTrailers.Count == 0))
{
_streamEnded = true;
http2HeadersFrame = Http2HeadersFrameFlags.END_STREAM;
@@ -313,7 +313,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
{
readResult = await _dataPipe.Reader.ReadAsync();
- if (readResult.IsCompleted && _stream.Trailers?.Count > 0)
+ if (readResult.IsCompleted && _stream.ResponseTrailers?.Count > 0)
{
// Output is ending and there are trailers to write
// Write any remaining content then write trailers
@@ -322,7 +322,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
flushResult = await _frameWriter.WriteDataAsync(_streamId, _flowControl, readResult.Buffer, endStream: false);
}
- flushResult = await _frameWriter.WriteResponseTrailers(_streamId, _stream.Trailers);
+ _stream.ResponseTrailers.SetReadOnly();
+ flushResult = await _frameWriter.WriteResponseTrailers(_streamId, _stream.ResponseTrailers);
}
else if (readResult.IsCompleted && _streamEnded)
{
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs
index 7187fc8462..fb27e387d3 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs
@@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
+using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
@@ -11,21 +12,25 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
{
internal partial class Http2Stream : IHttp2StreamIdFeature,
IHttpMinRequestBodyDataRateFeature,
+ IHttpResponseCompletionFeature,
IHttpResponseTrailersFeature
{
- internal HttpResponseTrailers Trailers { get; set; }
private IHeaderDictionary _userTrailers;
IHeaderDictionary IHttpResponseTrailersFeature.Trailers
{
get
{
- if (Trailers == null)
+ if (ResponseTrailers == null)
{
- Trailers = new HttpResponseTrailers();
+ ResponseTrailers = new HttpResponseTrailers();
+ if (HasResponseCompleted)
+ {
+ ResponseTrailers.SetReadOnly();
+ }
}
- return _userTrailers ?? Trailers;
+ return _userTrailers ?? ResponseTrailers;
}
set
{
@@ -48,5 +53,25 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
MinRequestBodyDataRate = value;
}
}
+
+ async Task IHttpResponseCompletionFeature.CompleteAsync()
+ {
+ // Finalize headers
+ if (!HasResponseStarted)
+ {
+ await FireOnStarting();
+ }
+
+ // Flush headers, body, trailers...
+ if (!HasResponseCompleted)
+ {
+ if (!VerifyResponseContentLength(out var lengthException))
+ {
+ throw lengthException;
+ }
+
+ await ProduceEnd();
+ }
+ }
}
}
diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs
index 09c9dafae6..7bdaa893a9 100644
--- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs
+++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs
@@ -1839,6 +1839,32 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
}
+ [Fact]
+ public async Task ResponseTrailers_WithExeption500_Cleared()
+ {
+ await InitializeConnectionAsync(context =>
+ {
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+ throw new NotImplementedException("Test Exception");
+ });
+
+ await StartStreamAsync(1, _browserRequestHeaders, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 55,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_STREAM | Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: true, handler: this);
+
+ Assert.Equal(3, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("500", _decodedHeaders[HeaderNames.Status]);
+ Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
+ }
+
[Fact]
public async Task ResponseTrailers_WithData_Sent()
{
@@ -3307,5 +3333,779 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
Assert.Contains(TestSink.Writes, w => w.EventId.Id == 13 && w.LogLevel == LogLevel.Error
&& w.Exception is ConnectionAbortedException && w.Exception.InnerException == expectedException);
}
+
+ [Fact]
+ public async Task CompleteAsync_BeforeBodyStarted_SendsHeadersWithEndStream()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 55,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(3, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+ Assert.Equal("0", _decodedHeaders["content-length"]);
+ }
+
+ [Fact]
+ public async Task CompleteAsync_BeforeBodyStarted_WithTrailers_SendsHeadersAndTrailersWithEndStream()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+ await completionFeature.CompleteAsync().DefaultTimeout(); // Can be called twice, no-ops
+
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 55,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 25,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(3, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+ Assert.Equal("0", _decodedHeaders["content-length"]);
+
+ _decodedHeaders.Clear();
+
+ _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
+
+ Assert.Single(_decodedHeaders);
+ Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
+ }
+
+ [Fact]
+ public async Task CompleteAsync_BeforeBodyStarted_WithTrailers_TruncatedContentLength_ThrowsAnd500()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ context.Response.ContentLength = 25;
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ var ex = await Assert.ThrowsAsync(() => completionFeature.CompleteAsync().DefaultTimeout());
+ Assert.Equal(CoreStrings.FormatTooFewBytesWritten(0, 25), ex.Message);
+
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully);
+ Assert.False(context.Response.Headers.IsReadOnly);
+ Assert.False(context.Features.Get().Trailers.IsReadOnly);
+
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 55,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(3, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("500", _decodedHeaders[HeaderNames.Status]);
+ Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
+ }
+
+ [Fact]
+ public async Task CompleteAsync_AfterBodyStarted_SendsBodyWithEndStream()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await context.Response.WriteAsync("Hello World");
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+ await completionFeature.CompleteAsync().DefaultTimeout(); // Can be called twice, no-ops
+
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 37,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+ await ExpectAsync(Http2FrameType.DATA,
+ withLength: 0,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(2, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+ }
+
+ [Fact]
+ public async Task CompleteAsync_WriteAfterComplete_Throws()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ var ex = await Assert.ThrowsAsync(() => context.Response.WriteAsync("2 Hello World").DefaultTimeout());
+ Assert.Equal("Writing is not allowed after writer was completed.", ex.Message);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 55,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(3, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+ Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
+ }
+
+ [Fact]
+ public async Task CompleteAsync_WriteAgainAfterComplete_Throws()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await context.Response.WriteAsync("Hello World").DefaultTimeout();
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ var ex = await Assert.ThrowsAsync(() => context.Response.WriteAsync("2 Hello World").DefaultTimeout());
+ Assert.Equal("Writing is not allowed after writer was completed.", ex.Message);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 37,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+ await ExpectAsync(Http2FrameType.DATA,
+ withLength: 0,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(2, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+ }
+
+ [Fact]
+ public async Task CompleteAsync_AfterPipeWrite_WithTrailers_SendsBodyAndTrailersWithEndStream()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ var buffer = context.Response.BodyWriter.GetMemory();
+ var length = Encoding.UTF8.GetBytes("Hello World", buffer.Span);
+ context.Response.BodyWriter.Advance(length);
+
+ Assert.False(startingTcs.Task.IsCompletedSuccessfully); // OnStarting did not get called.
+ Assert.False(context.Response.Headers.IsReadOnly);
+
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 37,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+ var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 25,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(2, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+
+ _decodedHeaders.Clear();
+
+ _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
+
+ Assert.Single(_decodedHeaders);
+ Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
+ }
+
+ [Fact]
+ public async Task CompleteAsync_AfterBodyStarted_WithTrailers_SendsBodyAndTrailersWithEndStream()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await context.Response.WriteAsync("Hello World");
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 37,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+ var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 25,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(2, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+
+ _decodedHeaders.Clear();
+
+ _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
+
+ Assert.Single(_decodedHeaders);
+ Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
+ }
+
+ [Fact]
+ public async Task CompleteAsync_AfterBodyStarted_WithTrailers_TruncatedContentLength_ThrowsAndReset()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ context.Response.ContentLength = 25;
+ await context.Response.WriteAsync("Hello World");
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ var ex = await Assert.ThrowsAsync(() => completionFeature.CompleteAsync().DefaultTimeout());
+ Assert.Equal(CoreStrings.FormatTooFewBytesWritten(11, 25), ex.Message);
+
+ Assert.False(context.Features.Get().Trailers.IsReadOnly);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 56,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+
+ await WaitForStreamErrorAsync(1, Http2ErrorCode.INTERNAL_ERROR,
+ expectedErrorMessage: CoreStrings.FormatTooFewBytesWritten(11, 25));
+
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(3, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+ Assert.Equal("25", _decodedHeaders[HeaderNames.ContentLength]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+ }
+
+ [Fact]
+ public async Task AbortAfterCompleteAsync_GETWithResponseBodyAndTrailers_ResetsAfterResponse()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await context.Response.WriteAsync("Hello World");
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // RequestAborted will no longer fire after CompleteAsync.
+ Assert.False(context.RequestAborted.CanBeCanceled);
+ context.Abort();
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 37,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+ var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 25,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+ await WaitForStreamErrorAsync(1, Http2ErrorCode.INTERNAL_ERROR, expectedErrorMessage: null);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(2, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+
+ _decodedHeaders.Clear();
+
+ _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
+
+ Assert.Single(_decodedHeaders);
+ Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
+ }
+
+ [Fact]
+ public async Task AbortAfterCompleteAsync_POSTWithResponseBodyAndTrailers_RequestBodyThrows()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "POST"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ var requestBodyTask = context.Request.BodyReader.ReadAsync();
+
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await context.Response.WriteAsync("Hello World");
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // RequestAborted will no longer fire after CompleteAsync.
+ Assert.False(context.RequestAborted.CanBeCanceled);
+ context.Abort();
+
+ await Assert.ThrowsAsync(async () => await requestBodyTask);
+ await Assert.ThrowsAsync(async () => await context.Request.BodyReader.ReadAsync());
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: false);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 37,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+ var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 25,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+ await WaitForStreamErrorAsync(1, Http2ErrorCode.INTERNAL_ERROR, expectedErrorMessage: null);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(2, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+
+ _decodedHeaders.Clear();
+
+ _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
+
+ Assert.Single(_decodedHeaders);
+ Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
+ }
}
}
diff --git a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs
index 8266770719..d30a30a9eb 100644
--- a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs
+++ b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs
@@ -35,6 +35,7 @@ namespace CodeGenerator
{
"IHttpUpgradeFeature",
"IHttp2StreamIdFeature",
+ "IHttpResponseCompletionFeature",
"IHttpResponseTrailersFeature",
"IResponseCookiesFeature",
"IItemsFeature",