Some minor clean up of Stream implementations (#14986)

* Some minor clean up of Stream implementations
- Use TaskToApm from corefx to implement Begin/End in streams
- Use PipeReader.CopyToAsync(Stream) to implement CopyToAsync
- Add more overrides on derived Streams in IIS
This commit is contained in:
David Fowler 2019-10-16 22:04:29 -07:00 committed by GitHub
parent 16be9a264e
commit 9098a47dbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 194 additions and 284 deletions

View File

@ -60,9 +60,29 @@ namespace Microsoft.AspNetCore.Server.IIS.Core
return _requestBody.ReadAsync(buffer, offset, count, cancellationToken);
}
public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
return _requestBody.ReadAsync(buffer, cancellationToken);
}
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _responseBody.WriteAsync(buffer, offset, count, cancellationToken);
}
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
return _responseBody.WriteAsync(buffer, cancellationToken);
}
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
return _requestBody.CopyToAsync(destination, bufferSize, cancellationToken);
}
public override Task FlushAsync(CancellationToken cancellationToken)
{
return _responseBody.FlushAsync(cancellationToken);
}
}
}

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;
@ -35,40 +36,12 @@ namespace Microsoft.AspNetCore.Server.IIS.Core
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
var task = ReadAsync(buffer, offset, count, default(CancellationToken), state);
if (callback != null)
{
task.ContinueWith(t => callback.Invoke(t));
}
return task;
return TaskToApm.Begin(ReadAsync(buffer, offset, count), callback, state);
}
public override int EndRead(IAsyncResult asyncResult)
{
return ((Task<int>)asyncResult).GetAwaiter().GetResult();
}
private Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
{
var tcs = new TaskCompletionSource<int>(state);
var task = ReadAsync(buffer, offset, count, cancellationToken);
task.ContinueWith((task2, state2) =>
{
var tcs2 = (TaskCompletionSource<int>)state2;
if (task2.IsCanceled)
{
tcs2.SetCanceled();
}
else if (task2.IsFaulted)
{
tcs2.SetException(task2.Exception);
}
else
{
tcs2.SetResult(task2.Result);
}
}, tcs, cancellationToken);
return tcs.Task;
return TaskToApm.End<int>(asyncResult);
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
@ -97,6 +70,18 @@ namespace Microsoft.AspNetCore.Server.IIS.Core
}
}
public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
try
{
await _body.CopyToAsync(destination, cancellationToken);
}
catch (ConnectionAbortedException ex)
{
throw new TaskCanceledException("The request was aborted", ex);
}
}
public void StartAcceptingReads(IISHttpContext body)
{
// Only start if not aborted

View File

@ -24,7 +24,7 @@ namespace Microsoft.AspNetCore.Server.IIS.Core
public override void Flush()
{
FlushAsync(default(CancellationToken)).GetAwaiter().GetResult();
FlushAsync(default).GetAwaiter().GetResult();
}
public override Task FlushAsync(CancellationToken cancellationToken)
@ -41,45 +41,17 @@ namespace Microsoft.AspNetCore.Server.IIS.Core
throw new InvalidOperationException(CoreStrings.SynchronousWritesDisallowed);
}
WriteAsync(buffer, offset, count, default(CancellationToken)).GetAwaiter().GetResult();
WriteAsync(buffer, offset, count, default).GetAwaiter().GetResult();
}
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
var task = WriteAsync(buffer, offset, count, default(CancellationToken), state);
if (callback != null)
{
task.ContinueWith(t => callback.Invoke(t));
}
return task;
return TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state);
}
public override void EndWrite(IAsyncResult asyncResult)
{
((Task<object>)asyncResult).GetAwaiter().GetResult();
}
private Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
{
var tcs = new TaskCompletionSource<object>(state);
var task = WriteAsync(buffer, offset, count, cancellationToken);
task.ContinueWith((task2, state2) =>
{
var tcs2 = (TaskCompletionSource<object>)state2;
if (task2.IsCanceled)
{
tcs2.SetCanceled();
}
else if (task2.IsFaulted)
{
tcs2.SetException(task2.Exception);
}
else
{
tcs2.SetResult(null);
}
}, tcs, cancellationToken);
return tcs.Task;
TaskToApm.End(asyncResult);
}
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)

View File

@ -3,11 +3,10 @@
using System;
using System.Buffers;
using System.Net.Http;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
namespace Microsoft.AspNetCore.Server.IIS.Core
@ -54,6 +53,16 @@ namespace Microsoft.AspNetCore.Server.IIS.Core
}
}
internal Task CopyToAsync(Stream destination, CancellationToken cancellationToken)
{
if (!HasStartedConsumingRequestBody)
{
InitializeRequestIO();
}
return _bodyInputPipe.Reader.CopyToAsync(destination, cancellationToken);
}
/// <summary>
/// Writes data to the output pipe.
/// </summary>

View File

@ -19,6 +19,7 @@
<Compile Include="$(SharedSourceRoot)StackTrace\**\*.cs" LinkBase="Shared\" />
<Compile Include="$(SharedSourceRoot)RazorViews\*.cs" LinkBase="Shared\" />
<Compile Include="$(SharedSourceRoot)ErrorPage\*.cs" LinkBase="Shared\" />
<Compile Include="$(RepoRoot)src\Shared\TaskToApm.cs" Link="Shared\TaskToApm.cs" />
</ItemGroup>
<Target Name="ValidateNativeComponentsBuilt" AfterTargets="Build" >

View File

@ -90,41 +90,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
var task = ReadAsync(buffer, offset, count, default, state);
if (callback != null)
{
task.ContinueWith(t => callback.Invoke(t));
}
return task;
return TaskToApm.Begin(ReadAsync(buffer, offset, count), callback, state);
}
/// <inheritdoc />
public override int EndRead(IAsyncResult asyncResult)
{
return ((Task<int>)asyncResult).GetAwaiter().GetResult();
}
private Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
{
var tcs = new TaskCompletionSource<int>(state);
var task = ReadAsync(buffer, offset, count, cancellationToken);
task.ContinueWith((task2, state2) =>
{
var tcs2 = (TaskCompletionSource<int>)state2;
if (task2.IsCanceled)
{
tcs2.SetCanceled();
}
else if (task2.IsFaulted)
{
tcs2.SetException(task2.Exception);
}
else
{
tcs2.SetResult(task2.Result);
}
}, tcs, cancellationToken);
return tcs.Task;
return TaskToApm.End<int>(asyncResult);
}
private ValueTask<int> ReadAsyncWrapper(Memory<byte> destination, CancellationToken cancellationToken)
@ -139,7 +111,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
}
}
private async ValueTask<int> ReadAsyncInternal(Memory<byte> buffer, CancellationToken cancellationToken)
private async ValueTask<int> ReadAsyncInternal(Memory<byte> destination, CancellationToken cancellationToken)
{
while (true)
{
@ -150,19 +122,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
throw new OperationCanceledException("The read was canceled");
}
var readableBuffer = result.Buffer;
var readableBufferLength = readableBuffer.Length;
var buffer = result.Buffer;
var length = buffer.Length;
var consumed = readableBuffer.End;
var consumed = buffer.End;
try
{
if (readableBufferLength != 0)
if (length != 0)
{
var actual = (int)Math.Min(readableBufferLength, buffer.Length);
var actual = (int)Math.Min(length, destination.Length);
var slice = actual == readableBufferLength ? readableBuffer : readableBuffer.Slice(0, actual);
var slice = actual == length ? buffer : buffer.Slice(0, actual);
consumed = slice.End;
slice.CopyTo(buffer.Span);
slice.CopyTo(destination.Span);
return actual;
}
@ -193,37 +165,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
throw new ArgumentOutOfRangeException(nameof(bufferSize));
}
return CopyToAsyncInternal(destination, cancellationToken);
}
private async Task CopyToAsyncInternal(Stream destination, CancellationToken cancellationToken)
{
while (true)
{
var result = await _pipeReader.ReadAsync(cancellationToken);
var readableBuffer = result.Buffer;
var readableBufferLength = readableBuffer.Length;
try
{
if (readableBufferLength != 0)
{
foreach (var memory in readableBuffer)
{
await destination.WriteAsync(memory, cancellationToken);
}
}
if (result.IsCompleted)
{
return;
}
}
finally
{
_pipeReader.AdvanceTo(readableBuffer.End);
}
}
return _pipeReader.CopyToAsync(destination, cancellationToken);
}
}
}

View File

@ -3,7 +3,6 @@
using System;
using System.IO;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Features;
@ -87,40 +86,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
var task = WriteAsync(buffer, offset, count, default, state);
if (callback != null)
{
task.ContinueWith(t => callback.Invoke(t));
}
return task;
return TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state);
}
public override void EndWrite(IAsyncResult asyncResult)
{
((Task<object>)asyncResult).GetAwaiter().GetResult();
}
private Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
{
var tcs = new TaskCompletionSource<object>(state);
var task = WriteAsync(buffer, offset, count, cancellationToken);
task.ContinueWith((task2, state2) =>
{
var tcs2 = (TaskCompletionSource<object>)state2;
if (task2.IsCanceled)
{
tcs2.SetCanceled();
}
else if (task2.IsFaulted)
{
tcs2.SetException(task2.Exception);
}
else
{
tcs2.SetResult(null);
}
}, tcs, cancellationToken);
return tcs.Task;
TaskToApm.End(asyncResult);
}
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)

View File

@ -15,6 +15,7 @@
<Compile Include="$(SharedSourceRoot)CertificateGeneration\**\*.cs" />
<Compile Include="$(SharedSourceRoot)ValueTaskExtensions\**\*.cs" />
<Compile Include="$(SharedSourceRoot)UrlDecoder\**\*.cs" />
<Compile Include="$(RepoRoot)src\Shared\TaskToApm.cs" Link="Internal\TaskToApm.cs" />
</ItemGroup>
<ItemGroup>

View File

@ -152,78 +152,22 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
var task = ReadAsync(buffer, offset, count, default, state);
if (callback != null)
{
task.ContinueWith(t => callback.Invoke(t));
}
return task;
return TaskToApm.Begin(ReadAsync(buffer, offset, count), callback, state);
}
public override int EndRead(IAsyncResult asyncResult)
{
return ((Task<int>)asyncResult).GetAwaiter().GetResult();
}
private Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
{
var tcs = new TaskCompletionSource<int>(state);
var task = ReadAsync(buffer, offset, count, cancellationToken);
task.ContinueWith((task2, state2) =>
{
var tcs2 = (TaskCompletionSource<int>)state2;
if (task2.IsCanceled)
{
tcs2.SetCanceled();
}
else if (task2.IsFaulted)
{
tcs2.SetException(task2.Exception);
}
else
{
tcs2.SetResult(task2.Result);
}
}, tcs, cancellationToken);
return tcs.Task;
return TaskToApm.End<int>(asyncResult);
}
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
var task = WriteAsync(buffer, offset, count, default, state);
if (callback != null)
{
task.ContinueWith(t => callback.Invoke(t));
}
return task;
return TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state);
}
public override void EndWrite(IAsyncResult asyncResult)
{
((Task<object>)asyncResult).GetAwaiter().GetResult();
}
private Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
{
var tcs = new TaskCompletionSource<object>(state);
var task = WriteAsync(buffer, offset, count, cancellationToken);
task.ContinueWith((task2, state2) =>
{
var tcs2 = (TaskCompletionSource<object>)state2;
if (task2.IsCanceled)
{
tcs2.SetCanceled();
}
else if (task2.IsFaulted)
{
tcs2.SetException(task2.Exception);
}
else
{
tcs2.SetResult(null);
}
}, tcs, cancellationToken);
return tcs.Task;
TaskToApm.End(asyncResult);
}
}
}

View File

@ -176,78 +176,22 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal
// The below APM methods call the underlying Read/WriteAsync methods which will still be logged.
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
var task = ReadAsync(buffer, offset, count, default(CancellationToken), state);
if (callback != null)
{
task.ContinueWith(t => callback.Invoke(t));
}
return task;
return TaskToApm.Begin(ReadAsync(buffer, offset, count), callback, state);
}
public override int EndRead(IAsyncResult asyncResult)
{
return ((Task<int>)asyncResult).GetAwaiter().GetResult();
}
private Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
{
var tcs = new TaskCompletionSource<int>(state);
var task = ReadAsync(buffer, offset, count, cancellationToken);
task.ContinueWith((task2, state2) =>
{
var tcs2 = (TaskCompletionSource<int>)state2;
if (task2.IsCanceled)
{
tcs2.SetCanceled();
}
else if (task2.IsFaulted)
{
tcs2.SetException(task2.Exception);
}
else
{
tcs2.SetResult(task2.Result);
}
}, tcs, cancellationToken);
return tcs.Task;
return TaskToApm.End<int>(asyncResult);
}
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
var task = WriteAsync(buffer, offset, count, default(CancellationToken), state);
if (callback != null)
{
task.ContinueWith(t => callback.Invoke(t));
}
return task;
return TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state);
}
public override void EndWrite(IAsyncResult asyncResult)
{
((Task<object>)asyncResult).GetAwaiter().GetResult();
}
private Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
{
var tcs = new TaskCompletionSource<object>(state);
var task = WriteAsync(buffer, offset, count, cancellationToken);
task.ContinueWith((task2, state2) =>
{
var tcs2 = (TaskCompletionSource<object>)state2;
if (task2.IsCanceled)
{
tcs2.SetCanceled();
}
else if (task2.IsFaulted)
{
tcs2.SetException(task2.Exception);
}
else
{
tcs2.SetResult(null);
}
}, tcs, cancellationToken);
return tcs.Task;
TaskToApm.End(asyncResult);
}
}
}

121
src/Shared/TaskToApm.cs Normal file
View File

@ -0,0 +1,121 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
// Helper methods for using Tasks to implement the APM pattern.
//
// Example usage, wrapping a Task<int>-returning FooAsync method with Begin/EndFoo methods:
//
// public IAsyncResult BeginFoo(..., AsyncCallback callback, object state) =>
// TaskToApm.Begin(FooAsync(...), callback, state);
//
// public int EndFoo(IAsyncResult asyncResult) =>
// TaskToApm.End<int>(asyncResult);
#nullable enable
using System.Diagnostics;
namespace System.Threading.Tasks
{
/// <summary>
/// Provides support for efficiently using Tasks to implement the APM (Begin/End) pattern.
/// </summary>
internal static class TaskToApm
{
/// <summary>
/// Marshals the Task as an IAsyncResult, using the supplied callback and state
/// to implement the APM pattern.
/// </summary>
/// <param name="task">The Task to be marshaled.</param>
/// <param name="callback">The callback to be invoked upon completion.</param>
/// <param name="state">The state to be stored in the IAsyncResult.</param>
/// <returns>An IAsyncResult to represent the task's asynchronous operation.</returns>
public static IAsyncResult Begin(Task task, AsyncCallback? callback, object? state) =>
new TaskAsyncResult(task, state, callback);
/// <summary>Processes an IAsyncResult returned by Begin.</summary>
/// <param name="asyncResult">The IAsyncResult to unwrap.</param>
public static void End(IAsyncResult asyncResult)
{
if (asyncResult is TaskAsyncResult twar)
{
twar._task.GetAwaiter().GetResult();
return;
}
throw new ArgumentNullException();
}
/// <summary>Processes an IAsyncResult returned by Begin.</summary>
/// <param name="asyncResult">The IAsyncResult to unwrap.</param>
public static TResult End<TResult>(IAsyncResult asyncResult)
{
if (asyncResult is TaskAsyncResult twar && twar._task is Task<TResult> task)
{
return task.GetAwaiter().GetResult();
}
throw new ArgumentNullException();
}
/// <summary>Provides a simple IAsyncResult that wraps a Task.</summary>
/// <remarks>
/// We could use the Task as the IAsyncResult if the Task's AsyncState is the same as the object state,
/// but that's very rare, in particular in a situation where someone cares about allocation, and always
/// using TaskAsyncResult simplifies things and enables additional optimizations.
/// </remarks>
internal sealed class TaskAsyncResult : IAsyncResult
{
/// <summary>The wrapped Task.</summary>
internal readonly Task _task;
/// <summary>Callback to invoke when the wrapped task completes.</summary>
private readonly AsyncCallback? _callback;
/// <summary>Initializes the IAsyncResult with the Task to wrap and the associated object state.</summary>
/// <param name="task">The Task to wrap.</param>
/// <param name="state">The new AsyncState value.</param>
/// <param name="callback">Callback to invoke when the wrapped task completes.</param>
internal TaskAsyncResult(Task task, object? state, AsyncCallback? callback)
{
Debug.Assert(task != null);
_task = task;
AsyncState = state;
if (task.IsCompleted)
{
// Synchronous completion. Invoke the callback. No need to store it.
CompletedSynchronously = true;
callback?.Invoke(this);
}
else if (callback != null)
{
// Asynchronous completion, and we have a callback; schedule it. We use OnCompleted rather than ContinueWith in
// order to avoid running synchronously if the task has already completed by the time we get here but still run
// synchronously as part of the task's completion if the task completes after (the more common case).
_callback = callback;
_task.ConfigureAwait(continueOnCapturedContext: false)
.GetAwaiter()
.OnCompleted(InvokeCallback); // allocates a delegate, but avoids a closure
}
}
/// <summary>Invokes the callback.</summary>
private void InvokeCallback()
{
Debug.Assert(!CompletedSynchronously);
Debug.Assert(_callback != null);
_callback.Invoke(this);
}
/// <summary>Gets a user-defined object that qualifies or contains information about an asynchronous operation.</summary>
public object? AsyncState { get; }
/// <summary>Gets a value that indicates whether the asynchronous operation completed synchronously.</summary>
/// <remarks>This is set lazily based on whether the <see cref="_task"/> has completed by the time this object is created.</remarks>
public bool CompletedSynchronously { get; }
/// <summary>Gets a value that indicates whether the asynchronous operation has completed.</summary>
public bool IsCompleted => _task.IsCompleted;
/// <summary>Gets a <see cref="WaitHandle"/> that is used to wait for an asynchronous operation to complete.</summary>
public WaitHandle AsyncWaitHandle => ((IAsyncResult)_task).AsyncWaitHandle;
}
}
}