Implement Http/2 CompleteAsync #10886 (#11193)

This commit is contained in:
Chris Ross 2019-06-15 09:01:45 -07:00 committed by GitHub
parent 5872814a64
commit bc5bee7477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 915 additions and 24 deletions

View File

@ -200,6 +200,10 @@ namespace Microsoft.AspNetCore.Http.Features
bool Available { get; }
Microsoft.AspNetCore.Http.IHeaderDictionary Trailers { get; }
}
public partial interface IHttpResponseCompletionFeature
{
System.Threading.Tasks.Task CompleteAsync();
}
public partial interface IHttpResponseFeature
{
System.IO.Stream Body { get; set; }

View File

@ -0,0 +1,20 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Http.Features
{
/// <summary>
/// A feature to gracefully end a response.
/// </summary>
public interface IHttpResponseCompletionFeature
{
/// <summary>
/// Flush any remaining response headers, data, or trailers.
/// This may throw if the response is in an invalid state such as a Content-Length mismatch.
/// </summary>
/// <returns></returns>
Task CompleteAsync();
}
}

View File

@ -277,6 +277,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
protected void ResetHttp2Features()
{
_currentIHttp2StreamIdFeature = this;
_currentIHttpResponseCompletionFeature = this;
_currentIHttpResponseTrailersFeature = this;
}

View File

@ -29,6 +29,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private static readonly Type IFormFeatureType = typeof(IFormFeature);
private static readonly Type IHttpUpgradeFeatureType = typeof(IHttpUpgradeFeature);
private static readonly Type IHttp2StreamIdFeatureType = typeof(IHttp2StreamIdFeature);
private static readonly Type IHttpResponseCompletionFeatureType = typeof(IHttpResponseCompletionFeature);
private static readonly Type IHttpResponseTrailersFeatureType = typeof(IHttpResponseTrailersFeature);
private static readonly Type IResponseCookiesFeatureType = typeof(IResponseCookiesFeature);
private static readonly Type IItemsFeatureType = typeof(IItemsFeature);
@ -58,6 +59,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private object _currentIFormFeature;
private object _currentIHttpUpgradeFeature;
private object _currentIHttp2StreamIdFeature;
private object _currentIHttpResponseCompletionFeature;
private object _currentIHttpResponseTrailersFeature;
private object _currentIResponseCookiesFeature;
private object _currentIItemsFeature;
@ -98,6 +100,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
_currentIQueryFeature = null;
_currentIFormFeature = null;
_currentIHttp2StreamIdFeature = null;
_currentIHttpResponseCompletionFeature = null;
_currentIHttpResponseTrailersFeature = null;
_currentIResponseCookiesFeature = null;
_currentIItemsFeature = null;
@ -224,6 +227,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
feature = _currentIHttp2StreamIdFeature;
}
else if (key == IHttpResponseCompletionFeatureType)
{
feature = _currentIHttpResponseCompletionFeature;
}
else if (key == IHttpResponseTrailersFeatureType)
{
feature = _currentIHttpResponseTrailersFeature;
@ -348,6 +355,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
_currentIHttp2StreamIdFeature = value;
}
else if (key == IHttpResponseCompletionFeatureType)
{
_currentIHttpResponseCompletionFeature = value;
}
else if (key == IHttpResponseTrailersFeatureType)
{
_currentIHttpResponseTrailersFeature = value;
@ -470,6 +481,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
feature = (TFeature)_currentIHttp2StreamIdFeature;
}
else if (typeof(TFeature) == typeof(IHttpResponseCompletionFeature))
{
feature = (TFeature)_currentIHttpResponseCompletionFeature;
}
else if (typeof(TFeature) == typeof(IHttpResponseTrailersFeature))
{
feature = (TFeature)_currentIHttpResponseTrailersFeature;
@ -598,6 +613,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
_currentIHttp2StreamIdFeature = feature;
}
else if (typeof(TFeature) == typeof(IHttpResponseCompletionFeature))
{
_currentIHttpResponseCompletionFeature = feature;
}
else if (typeof(TFeature) == typeof(IHttpResponseTrailersFeature))
{
_currentIHttpResponseTrailersFeature = feature;
@ -718,6 +737,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
yield return new KeyValuePair<Type, object>(IHttp2StreamIdFeatureType, _currentIHttp2StreamIdFeature);
}
if (_currentIHttpResponseCompletionFeature != null)
{
yield return new KeyValuePair<Type, object>(IHttpResponseCompletionFeatureType, _currentIHttpResponseCompletionFeature);
}
if (_currentIHttpResponseTrailersFeature != null)
{
yield return new KeyValuePair<Type, object>(IHttpResponseTrailersFeatureType, _currentIHttpResponseTrailersFeature);

View File

@ -210,6 +210,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
public bool RequestTrailersAvailable { get; set; }
public Stream RequestBody { get; set; }
public PipeReader RequestBodyPipeReader { get; set; }
public HttpResponseTrailers ResponseTrailers { get; set; }
private int _statusCode;
public int StatusCode
@ -287,7 +288,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
public bool HasResponseStarted => _requestProcessingStatus >= RequestProcessingStatus.HeadersCommitted;
public bool HasFlushedHeaders => _requestProcessingStatus == RequestProcessingStatus.HeadersFlushed;
public bool HasFlushedHeaders => _requestProcessingStatus >= RequestProcessingStatus.HeadersFlushed;
public bool HasResponseCompleted => _requestProcessingStatus == RequestProcessingStatus.ResponseCompleted;
protected HttpRequestHeaders HttpRequestHeaders { get; }
@ -632,9 +635,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
// Run the application code for this request
await application.ProcessRequestAsync(context);
if (!_connectionAborted)
// Trigger OnStarting if it hasn't been called yet and the app hasn't
// already failed. If an OnStarting callback throws we can go through
// our normal error handling in ProduceEnd.
// https://github.com/aspnet/KestrelHttpServer/issues/43
if (!HasResponseStarted && _applicationException == null && _onStarting?.Count > 0)
{
VerifyResponseContentLength();
await FireOnStarting();
}
if (!_connectionAborted && !VerifyResponseContentLength(out var lengthException))
{
ReportApplicationError(lengthException);
}
}
catch (BadHttpRequestException ex)
@ -652,15 +664,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
KestrelEventSource.Log.RequestStop(this);
// Trigger OnStarting if it hasn't been called yet and the app hasn't
// already failed. If an OnStarting callback throws we can go through
// our normal error handling in ProduceEnd.
// https://github.com/aspnet/KestrelHttpServer/issues/43
if (!HasResponseStarted && _applicationException == null && _onStarting?.Count > 0)
{
await FireOnStarting();
}
// At this point all user code that needs use to the request or response streams has completed.
// Using these streams in the OnCompleted callback is not allowed.
StopBodies();
@ -898,7 +901,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
}
}
protected void VerifyResponseContentLength()
protected bool VerifyResponseContentLength(out Exception ex)
{
var responseHeaders = HttpResponseHeaders;
@ -915,9 +918,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
_keepAlive = false;
}
ReportApplicationError(new InvalidOperationException(
CoreStrings.FormatTooFewBytesWritten(_responseBytesWritten, responseHeaders.ContentLength.Value)));
ex = new InvalidOperationException(
CoreStrings.FormatTooFewBytesWritten(_responseBytesWritten, responseHeaders.ContentLength.Value));
return false;
}
ex = null;
return true;
}
public void ProduceContinue()
@ -1045,6 +1052,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private Task WriteSuffix()
{
if (HasResponseCompleted)
{
return Task.CompletedTask;
}
// _autoChunk should be checked after we are sure ProduceStart() has been called
// since ProduceStart() may set _autoChunk to true.
if (_autoChunk || _httpVersion == Http.HttpVersion.Http2)
@ -1064,7 +1076,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
if (!HasFlushedHeaders)
{
_requestProcessingStatus = RequestProcessingStatus.HeadersFlushed;
_requestProcessingStatus = RequestProcessingStatus.ResponseCompleted;
return FlushAsyncInternal();
}
@ -1080,6 +1092,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
await Output.WriteStreamSuffixAsync();
_requestProcessingStatus = RequestProcessingStatus.ResponseCompleted;
if (_keepAlive)
{
Log.ConnectionKeepAlive(ConnectionId);
@ -1244,6 +1258,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
var responseHeaders = HttpResponseHeaders;
responseHeaders.Reset();
ResponseTrailers?.Reset();
var dateHeaderValues = DateHeaderValueManager.GetDateHeaderValues();
responseHeaders.SetRawDate(dateHeaderValues.String, dateHeaderValues.Bytes);

View File

@ -10,6 +10,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
ParsingHeaders,
AppStarted,
HeadersCommitted,
HeadersFlushed
HeadersFlushed,
ResponseCompleted
}
}

View File

@ -151,7 +151,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
// 2. There is no trailing HEADERS frame.
Http2HeadersFrameFlags http2HeadersFrame;
if (appCompleted && !_startedWritingDataFrames && (_stream.Trailers == null || _stream.Trailers.Count == 0))
if (appCompleted && !_startedWritingDataFrames && (_stream.ResponseTrailers == null || _stream.ResponseTrailers.Count == 0))
{
_streamEnded = true;
http2HeadersFrame = Http2HeadersFrameFlags.END_STREAM;
@ -313,7 +313,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
{
readResult = await _dataPipe.Reader.ReadAsync();
if (readResult.IsCompleted && _stream.Trailers?.Count > 0)
if (readResult.IsCompleted && _stream.ResponseTrailers?.Count > 0)
{
// Output is ending and there are trailers to write
// Write any remaining content then write trailers
@ -322,7 +322,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
flushResult = await _frameWriter.WriteDataAsync(_streamId, _flowControl, readResult.Buffer, endStream: false);
}
flushResult = await _frameWriter.WriteResponseTrailers(_streamId, _stream.Trailers);
_stream.ResponseTrailers.SetReadOnly();
flushResult = await _frameWriter.WriteResponseTrailers(_streamId, _stream.ResponseTrailers);
}
else if (readResult.IsCompleted && _streamEnded)
{

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
@ -11,21 +12,25 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
{
internal partial class Http2Stream : IHttp2StreamIdFeature,
IHttpMinRequestBodyDataRateFeature,
IHttpResponseCompletionFeature,
IHttpResponseTrailersFeature
{
internal HttpResponseTrailers Trailers { get; set; }
private IHeaderDictionary _userTrailers;
IHeaderDictionary IHttpResponseTrailersFeature.Trailers
{
get
{
if (Trailers == null)
if (ResponseTrailers == null)
{
Trailers = new HttpResponseTrailers();
ResponseTrailers = new HttpResponseTrailers();
if (HasResponseCompleted)
{
ResponseTrailers.SetReadOnly();
}
}
return _userTrailers ?? Trailers;
return _userTrailers ?? ResponseTrailers;
}
set
{
@ -48,5 +53,25 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
MinRequestBodyDataRate = value;
}
}
async Task IHttpResponseCompletionFeature.CompleteAsync()
{
// Finalize headers
if (!HasResponseStarted)
{
await FireOnStarting();
}
// Flush headers, body, trailers...
if (!HasResponseCompleted)
{
if (!VerifyResponseContentLength(out var lengthException))
{
throw lengthException;
}
await ProduceEnd();
}
}
}
}

View File

@ -1839,6 +1839,32 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
}
[Fact]
public async Task ResponseTrailers_WithExeption500_Cleared()
{
await InitializeConnectionAsync(context =>
{
context.Response.AppendTrailer("CustomName", "Custom Value");
throw new NotImplementedException("Test Exception");
});
await StartStreamAsync(1, _browserRequestHeaders, endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 55,
withFlags: (byte)(Http2HeadersFrameFlags.END_STREAM | Http2HeadersFrameFlags.END_HEADERS),
withStreamId: 1);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: true, handler: this);
Assert.Equal(3, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("500", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
}
[Fact]
public async Task ResponseTrailers_WithData_Sent()
{
@ -3307,5 +3333,779 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
Assert.Contains(TestSink.Writes, w => w.EventId.Id == 13 && w.LogLevel == LogLevel.Error
&& w.Exception is ConnectionAbortedException && w.Exception.InnerException == expectedException);
}
[Fact]
public async Task CompleteAsync_BeforeBodyStarted_SendsHeadersWithEndStream()
{
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
try
{
context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
var completionFeature = context.Features.Get<IHttpResponseCompletionFeature>();
Assert.NotNull(completionFeature);
await completionFeature.CompleteAsync().DefaultTimeout();
Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
Assert.True(context.Response.Headers.IsReadOnly);
Assert.True(context.Features.Get<IHttpResponseTrailersFeature>().Trailers.IsReadOnly);
// Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
await clientTcs.Task.DefaultTimeout();
appTcs.SetResult(0);
}
catch (Exception ex)
{
appTcs.SetException(ex);
}
});
await StartStreamAsync(1, headers, endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 55,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
withStreamId: 1);
clientTcs.SetResult(0);
await appTcs.Task;
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(3, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("0", _decodedHeaders["content-length"]);
}
[Fact]
public async Task CompleteAsync_BeforeBodyStarted_WithTrailers_SendsHeadersAndTrailersWithEndStream()
{
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
try
{
context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
var completionFeature = context.Features.Get<IHttpResponseCompletionFeature>();
Assert.NotNull(completionFeature);
context.Response.AppendTrailer("CustomName", "Custom Value");
await completionFeature.CompleteAsync().DefaultTimeout();
await completionFeature.CompleteAsync().DefaultTimeout(); // Can be called twice, no-ops
Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
Assert.True(context.Response.Headers.IsReadOnly);
Assert.True(context.Features.Get<IHttpResponseTrailersFeature>().Trailers.IsReadOnly);
// Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
await clientTcs.Task.DefaultTimeout();
appTcs.SetResult(0);
}
catch (Exception ex)
{
appTcs.SetException(ex);
}
});
await StartStreamAsync(1, headers, endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 55,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
withStreamId: 1);
var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 25,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
withStreamId: 1);
clientTcs.SetResult(0);
await appTcs.Task;
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(3, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("0", _decodedHeaders["content-length"]);
_decodedHeaders.Clear();
_hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
Assert.Single(_decodedHeaders);
Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
}
[Fact]
public async Task CompleteAsync_BeforeBodyStarted_WithTrailers_TruncatedContentLength_ThrowsAnd500()
{
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
try
{
context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
var completionFeature = context.Features.Get<IHttpResponseCompletionFeature>();
Assert.NotNull(completionFeature);
context.Response.ContentLength = 25;
context.Response.AppendTrailer("CustomName", "Custom Value");
var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => completionFeature.CompleteAsync().DefaultTimeout());
Assert.Equal(CoreStrings.FormatTooFewBytesWritten(0, 25), ex.Message);
Assert.True(startingTcs.Task.IsCompletedSuccessfully);
Assert.False(context.Response.Headers.IsReadOnly);
Assert.False(context.Features.Get<IHttpResponseTrailersFeature>().Trailers.IsReadOnly);
appTcs.SetResult(0);
}
catch (Exception ex)
{
appTcs.SetException(ex);
}
});
await StartStreamAsync(1, headers, endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 55,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
withStreamId: 1);
await appTcs.Task;
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(3, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("500", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
}
[Fact]
public async Task CompleteAsync_AfterBodyStarted_SendsBodyWithEndStream()
{
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
try
{
context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
var completionFeature = context.Features.Get<IHttpResponseCompletionFeature>();
Assert.NotNull(completionFeature);
await context.Response.WriteAsync("Hello World");
Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
Assert.True(context.Response.Headers.IsReadOnly);
await completionFeature.CompleteAsync().DefaultTimeout();
await completionFeature.CompleteAsync().DefaultTimeout(); // Can be called twice, no-ops
Assert.True(context.Features.Get<IHttpResponseTrailersFeature>().Trailers.IsReadOnly);
// Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
await clientTcs.Task.DefaultTimeout();
appTcs.SetResult(0);
}
catch (Exception ex)
{
appTcs.SetException(ex);
}
});
await StartStreamAsync(1, headers, endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 37,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
withStreamId: 1);
var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
withLength: 11,
withFlags: (byte)(Http2HeadersFrameFlags.NONE),
withStreamId: 1);
await ExpectAsync(Http2FrameType.DATA,
withLength: 0,
withFlags: (byte)(Http2HeadersFrameFlags.END_STREAM),
withStreamId: 1);
clientTcs.SetResult(0);
await appTcs.Task;
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(2, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
}
[Fact]
public async Task CompleteAsync_WriteAfterComplete_Throws()
{
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
try
{
context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
var completionFeature = context.Features.Get<IHttpResponseCompletionFeature>();
Assert.NotNull(completionFeature);
await completionFeature.CompleteAsync().DefaultTimeout();
Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
Assert.True(context.Response.Headers.IsReadOnly);
Assert.True(context.Features.Get<IHttpResponseTrailersFeature>().Trailers.IsReadOnly);
var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => context.Response.WriteAsync("2 Hello World").DefaultTimeout());
Assert.Equal("Writing is not allowed after writer was completed.", ex.Message);
// Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
await clientTcs.Task.DefaultTimeout();
appTcs.SetResult(0);
}
catch (Exception ex)
{
appTcs.SetException(ex);
}
});
await StartStreamAsync(1, headers, endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 55,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
withStreamId: 1);
clientTcs.SetResult(0);
await appTcs.Task;
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(3, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
}
[Fact]
public async Task CompleteAsync_WriteAgainAfterComplete_Throws()
{
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
try
{
context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
var completionFeature = context.Features.Get<IHttpResponseCompletionFeature>();
Assert.NotNull(completionFeature);
await context.Response.WriteAsync("Hello World").DefaultTimeout();
Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
Assert.True(context.Response.Headers.IsReadOnly);
await completionFeature.CompleteAsync().DefaultTimeout();
Assert.True(context.Features.Get<IHttpResponseTrailersFeature>().Trailers.IsReadOnly);
var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => context.Response.WriteAsync("2 Hello World").DefaultTimeout());
Assert.Equal("Writing is not allowed after writer was completed.", ex.Message);
// Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
await clientTcs.Task.DefaultTimeout();
appTcs.SetResult(0);
}
catch (Exception ex)
{
appTcs.SetException(ex);
}
});
await StartStreamAsync(1, headers, endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 37,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
withStreamId: 1);
var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
withLength: 11,
withFlags: (byte)(Http2HeadersFrameFlags.NONE),
withStreamId: 1);
await ExpectAsync(Http2FrameType.DATA,
withLength: 0,
withFlags: (byte)(Http2HeadersFrameFlags.END_STREAM),
withStreamId: 1);
clientTcs.SetResult(0);
await appTcs.Task;
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(2, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
}
[Fact]
public async Task CompleteAsync_AfterPipeWrite_WithTrailers_SendsBodyAndTrailersWithEndStream()
{
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
try
{
context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
var completionFeature = context.Features.Get<IHttpResponseCompletionFeature>();
Assert.NotNull(completionFeature);
var buffer = context.Response.BodyWriter.GetMemory();
var length = Encoding.UTF8.GetBytes("Hello World", buffer.Span);
context.Response.BodyWriter.Advance(length);
Assert.False(startingTcs.Task.IsCompletedSuccessfully); // OnStarting did not get called.
Assert.False(context.Response.Headers.IsReadOnly);
context.Response.AppendTrailer("CustomName", "Custom Value");
await completionFeature.CompleteAsync().DefaultTimeout();
Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
Assert.True(context.Response.Headers.IsReadOnly);
Assert.True(context.Features.Get<IHttpResponseTrailersFeature>().Trailers.IsReadOnly);
// Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
await clientTcs.Task.DefaultTimeout();
appTcs.SetResult(0);
}
catch (Exception ex)
{
appTcs.SetException(ex);
}
});
await StartStreamAsync(1, headers, endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 37,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
withStreamId: 1);
var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
withLength: 11,
withFlags: (byte)(Http2HeadersFrameFlags.NONE),
withStreamId: 1);
var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 25,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
withStreamId: 1);
clientTcs.SetResult(0);
await appTcs.Task;
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(2, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
_decodedHeaders.Clear();
_hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
Assert.Single(_decodedHeaders);
Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
}
[Fact]
public async Task CompleteAsync_AfterBodyStarted_WithTrailers_SendsBodyAndTrailersWithEndStream()
{
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
try
{
context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
var completionFeature = context.Features.Get<IHttpResponseCompletionFeature>();
Assert.NotNull(completionFeature);
await context.Response.WriteAsync("Hello World");
Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
Assert.True(context.Response.Headers.IsReadOnly);
context.Response.AppendTrailer("CustomName", "Custom Value");
await completionFeature.CompleteAsync().DefaultTimeout();
Assert.True(context.Features.Get<IHttpResponseTrailersFeature>().Trailers.IsReadOnly);
// Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
await clientTcs.Task.DefaultTimeout();
appTcs.SetResult(0);
}
catch (Exception ex)
{
appTcs.SetException(ex);
}
});
await StartStreamAsync(1, headers, endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 37,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
withStreamId: 1);
var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
withLength: 11,
withFlags: (byte)(Http2HeadersFrameFlags.NONE),
withStreamId: 1);
var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 25,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
withStreamId: 1);
clientTcs.SetResult(0);
await appTcs.Task;
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(2, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
_decodedHeaders.Clear();
_hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
Assert.Single(_decodedHeaders);
Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
}
[Fact]
public async Task CompleteAsync_AfterBodyStarted_WithTrailers_TruncatedContentLength_ThrowsAndReset()
{
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
try
{
context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
var completionFeature = context.Features.Get<IHttpResponseCompletionFeature>();
Assert.NotNull(completionFeature);
context.Response.ContentLength = 25;
await context.Response.WriteAsync("Hello World");
Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
Assert.True(context.Response.Headers.IsReadOnly);
context.Response.AppendTrailer("CustomName", "Custom Value");
var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => completionFeature.CompleteAsync().DefaultTimeout());
Assert.Equal(CoreStrings.FormatTooFewBytesWritten(11, 25), ex.Message);
Assert.False(context.Features.Get<IHttpResponseTrailersFeature>().Trailers.IsReadOnly);
// Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
await clientTcs.Task.DefaultTimeout();
appTcs.SetResult(0);
}
catch (Exception ex)
{
appTcs.SetException(ex);
}
});
await StartStreamAsync(1, headers, endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 56,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
withStreamId: 1);
var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
withLength: 11,
withFlags: (byte)(Http2HeadersFrameFlags.NONE),
withStreamId: 1);
clientTcs.SetResult(0);
await WaitForStreamErrorAsync(1, Http2ErrorCode.INTERNAL_ERROR,
expectedErrorMessage: CoreStrings.FormatTooFewBytesWritten(11, 25));
await appTcs.Task;
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(3, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("25", _decodedHeaders[HeaderNames.ContentLength]);
Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
}
[Fact]
public async Task AbortAfterCompleteAsync_GETWithResponseBodyAndTrailers_ResetsAfterResponse()
{
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "GET"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
try
{
context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
var completionFeature = context.Features.Get<IHttpResponseCompletionFeature>();
Assert.NotNull(completionFeature);
await context.Response.WriteAsync("Hello World");
Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
Assert.True(context.Response.Headers.IsReadOnly);
context.Response.AppendTrailer("CustomName", "Custom Value");
await completionFeature.CompleteAsync().DefaultTimeout();
Assert.True(context.Features.Get<IHttpResponseTrailersFeature>().Trailers.IsReadOnly);
// RequestAborted will no longer fire after CompleteAsync.
Assert.False(context.RequestAborted.CanBeCanceled);
context.Abort();
// Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
await clientTcs.Task.DefaultTimeout();
appTcs.SetResult(0);
}
catch (Exception ex)
{
appTcs.SetException(ex);
}
});
await StartStreamAsync(1, headers, endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 37,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
withStreamId: 1);
var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
withLength: 11,
withFlags: (byte)(Http2HeadersFrameFlags.NONE),
withStreamId: 1);
var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 25,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
withStreamId: 1);
await WaitForStreamErrorAsync(1, Http2ErrorCode.INTERNAL_ERROR, expectedErrorMessage: null);
clientTcs.SetResult(0);
await appTcs.Task;
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(2, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
_decodedHeaders.Clear();
_hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
Assert.Single(_decodedHeaders);
Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
}
[Fact]
public async Task AbortAfterCompleteAsync_POSTWithResponseBodyAndTrailers_RequestBodyThrows()
{
var startingTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var appTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var clientTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
try
{
var requestBodyTask = context.Request.BodyReader.ReadAsync();
context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
var completionFeature = context.Features.Get<IHttpResponseCompletionFeature>();
Assert.NotNull(completionFeature);
await context.Response.WriteAsync("Hello World");
Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
Assert.True(context.Response.Headers.IsReadOnly);
context.Response.AppendTrailer("CustomName", "Custom Value");
await completionFeature.CompleteAsync().DefaultTimeout();
Assert.True(context.Features.Get<IHttpResponseTrailersFeature>().Trailers.IsReadOnly);
// RequestAborted will no longer fire after CompleteAsync.
Assert.False(context.RequestAborted.CanBeCanceled);
context.Abort();
await Assert.ThrowsAsync<TaskCanceledException>(async () => await requestBodyTask);
await Assert.ThrowsAsync<ConnectionAbortedException>(async () => await context.Request.BodyReader.ReadAsync());
// Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
await clientTcs.Task.DefaultTimeout();
appTcs.SetResult(0);
}
catch (Exception ex)
{
appTcs.SetException(ex);
}
});
await StartStreamAsync(1, headers, endStream: false);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 37,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
withStreamId: 1);
var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
withLength: 11,
withFlags: (byte)(Http2HeadersFrameFlags.NONE),
withStreamId: 1);
var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 25,
withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
withStreamId: 1);
await WaitForStreamErrorAsync(1, Http2ErrorCode.INTERNAL_ERROR, expectedErrorMessage: null);
clientTcs.SetResult(0);
await appTcs.Task;
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(2, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
_decodedHeaders.Clear();
_hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
Assert.Single(_decodedHeaders);
Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
}
}
}

View File

@ -35,6 +35,7 @@ namespace CodeGenerator
{
"IHttpUpgradeFeature",
"IHttp2StreamIdFeature",
"IHttpResponseCompletionFeature",
"IHttpResponseTrailersFeature",
"IResponseCookiesFeature",
"IItemsFeature",