Some minor clean up of Stream implementations (#14986)
* Some minor clean up of Stream implementations - Use TaskToApm from corefx to implement Begin/End in streams - Use PipeReader.CopyToAsync(Stream) to implement CopyToAsync - Add more overrides on derived Streams in IIS
This commit is contained in:
parent
16be9a264e
commit
9098a47dbf
|
|
@ -60,9 +60,29 @@ namespace Microsoft.AspNetCore.Server.IIS.Core
|
|||
return _requestBody.ReadAsync(buffer, offset, count, cancellationToken);
|
||||
}
|
||||
|
||||
public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return _requestBody.ReadAsync(buffer, cancellationToken);
|
||||
}
|
||||
|
||||
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
||||
{
|
||||
return _responseBody.WriteAsync(buffer, offset, count, cancellationToken);
|
||||
}
|
||||
|
||||
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return _responseBody.WriteAsync(buffer, cancellationToken);
|
||||
}
|
||||
|
||||
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
|
||||
{
|
||||
return _requestBody.CopyToAsync(destination, bufferSize, cancellationToken);
|
||||
}
|
||||
|
||||
public override Task FlushAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
return _responseBody.FlushAsync(cancellationToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Runtime.ExceptionServices;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
|
@ -35,40 +36,12 @@ namespace Microsoft.AspNetCore.Server.IIS.Core
|
|||
|
||||
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
|
||||
{
|
||||
var task = ReadAsync(buffer, offset, count, default(CancellationToken), state);
|
||||
if (callback != null)
|
||||
{
|
||||
task.ContinueWith(t => callback.Invoke(t));
|
||||
}
|
||||
return task;
|
||||
return TaskToApm.Begin(ReadAsync(buffer, offset, count), callback, state);
|
||||
}
|
||||
|
||||
public override int EndRead(IAsyncResult asyncResult)
|
||||
{
|
||||
return ((Task<int>)asyncResult).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
private Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
|
||||
{
|
||||
var tcs = new TaskCompletionSource<int>(state);
|
||||
var task = ReadAsync(buffer, offset, count, cancellationToken);
|
||||
task.ContinueWith((task2, state2) =>
|
||||
{
|
||||
var tcs2 = (TaskCompletionSource<int>)state2;
|
||||
if (task2.IsCanceled)
|
||||
{
|
||||
tcs2.SetCanceled();
|
||||
}
|
||||
else if (task2.IsFaulted)
|
||||
{
|
||||
tcs2.SetException(task2.Exception);
|
||||
}
|
||||
else
|
||||
{
|
||||
tcs2.SetResult(task2.Result);
|
||||
}
|
||||
}, tcs, cancellationToken);
|
||||
return tcs.Task;
|
||||
return TaskToApm.End<int>(asyncResult);
|
||||
}
|
||||
|
||||
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
||||
|
|
@ -97,6 +70,18 @@ namespace Microsoft.AspNetCore.Server.IIS.Core
|
|||
}
|
||||
}
|
||||
|
||||
public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _body.CopyToAsync(destination, cancellationToken);
|
||||
}
|
||||
catch (ConnectionAbortedException ex)
|
||||
{
|
||||
throw new TaskCanceledException("The request was aborted", ex);
|
||||
}
|
||||
}
|
||||
|
||||
public void StartAcceptingReads(IISHttpContext body)
|
||||
{
|
||||
// Only start if not aborted
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ namespace Microsoft.AspNetCore.Server.IIS.Core
|
|||
|
||||
public override void Flush()
|
||||
{
|
||||
FlushAsync(default(CancellationToken)).GetAwaiter().GetResult();
|
||||
FlushAsync(default).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
public override Task FlushAsync(CancellationToken cancellationToken)
|
||||
|
|
@ -41,45 +41,17 @@ namespace Microsoft.AspNetCore.Server.IIS.Core
|
|||
throw new InvalidOperationException(CoreStrings.SynchronousWritesDisallowed);
|
||||
}
|
||||
|
||||
WriteAsync(buffer, offset, count, default(CancellationToken)).GetAwaiter().GetResult();
|
||||
WriteAsync(buffer, offset, count, default).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
|
||||
{
|
||||
var task = WriteAsync(buffer, offset, count, default(CancellationToken), state);
|
||||
if (callback != null)
|
||||
{
|
||||
task.ContinueWith(t => callback.Invoke(t));
|
||||
}
|
||||
return task;
|
||||
return TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state);
|
||||
}
|
||||
|
||||
public override void EndWrite(IAsyncResult asyncResult)
|
||||
{
|
||||
((Task<object>)asyncResult).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
private Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
|
||||
{
|
||||
var tcs = new TaskCompletionSource<object>(state);
|
||||
var task = WriteAsync(buffer, offset, count, cancellationToken);
|
||||
task.ContinueWith((task2, state2) =>
|
||||
{
|
||||
var tcs2 = (TaskCompletionSource<object>)state2;
|
||||
if (task2.IsCanceled)
|
||||
{
|
||||
tcs2.SetCanceled();
|
||||
}
|
||||
else if (task2.IsFaulted)
|
||||
{
|
||||
tcs2.SetException(task2.Exception);
|
||||
}
|
||||
else
|
||||
{
|
||||
tcs2.SetResult(null);
|
||||
}
|
||||
}, tcs, cancellationToken);
|
||||
return tcs.Task;
|
||||
TaskToApm.End(asyncResult);
|
||||
}
|
||||
|
||||
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
||||
|
|
|
|||
|
|
@ -3,11 +3,10 @@
|
|||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Net.Http;
|
||||
using System.IO;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Connections;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Http.Features;
|
||||
|
||||
namespace Microsoft.AspNetCore.Server.IIS.Core
|
||||
|
|
@ -54,6 +53,16 @@ namespace Microsoft.AspNetCore.Server.IIS.Core
|
|||
}
|
||||
}
|
||||
|
||||
internal Task CopyToAsync(Stream destination, CancellationToken cancellationToken)
|
||||
{
|
||||
if (!HasStartedConsumingRequestBody)
|
||||
{
|
||||
InitializeRequestIO();
|
||||
}
|
||||
|
||||
return _bodyInputPipe.Reader.CopyToAsync(destination, cancellationToken);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Writes data to the output pipe.
|
||||
/// </summary>
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
<Compile Include="$(SharedSourceRoot)StackTrace\**\*.cs" LinkBase="Shared\" />
|
||||
<Compile Include="$(SharedSourceRoot)RazorViews\*.cs" LinkBase="Shared\" />
|
||||
<Compile Include="$(SharedSourceRoot)ErrorPage\*.cs" LinkBase="Shared\" />
|
||||
<Compile Include="$(RepoRoot)src\Shared\TaskToApm.cs" Link="Shared\TaskToApm.cs" />
|
||||
</ItemGroup>
|
||||
|
||||
<Target Name="ValidateNativeComponentsBuilt" AfterTargets="Build" >
|
||||
|
|
|
|||
|
|
@ -90,41 +90,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
|
||||
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
|
||||
{
|
||||
var task = ReadAsync(buffer, offset, count, default, state);
|
||||
if (callback != null)
|
||||
{
|
||||
task.ContinueWith(t => callback.Invoke(t));
|
||||
}
|
||||
return task;
|
||||
return TaskToApm.Begin(ReadAsync(buffer, offset, count), callback, state);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override int EndRead(IAsyncResult asyncResult)
|
||||
{
|
||||
return ((Task<int>)asyncResult).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
private Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
|
||||
{
|
||||
var tcs = new TaskCompletionSource<int>(state);
|
||||
var task = ReadAsync(buffer, offset, count, cancellationToken);
|
||||
task.ContinueWith((task2, state2) =>
|
||||
{
|
||||
var tcs2 = (TaskCompletionSource<int>)state2;
|
||||
if (task2.IsCanceled)
|
||||
{
|
||||
tcs2.SetCanceled();
|
||||
}
|
||||
else if (task2.IsFaulted)
|
||||
{
|
||||
tcs2.SetException(task2.Exception);
|
||||
}
|
||||
else
|
||||
{
|
||||
tcs2.SetResult(task2.Result);
|
||||
}
|
||||
}, tcs, cancellationToken);
|
||||
return tcs.Task;
|
||||
return TaskToApm.End<int>(asyncResult);
|
||||
}
|
||||
|
||||
private ValueTask<int> ReadAsyncWrapper(Memory<byte> destination, CancellationToken cancellationToken)
|
||||
|
|
@ -139,7 +111,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
}
|
||||
}
|
||||
|
||||
private async ValueTask<int> ReadAsyncInternal(Memory<byte> buffer, CancellationToken cancellationToken)
|
||||
private async ValueTask<int> ReadAsyncInternal(Memory<byte> destination, CancellationToken cancellationToken)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
|
|
@ -150,19 +122,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
throw new OperationCanceledException("The read was canceled");
|
||||
}
|
||||
|
||||
var readableBuffer = result.Buffer;
|
||||
var readableBufferLength = readableBuffer.Length;
|
||||
var buffer = result.Buffer;
|
||||
var length = buffer.Length;
|
||||
|
||||
var consumed = readableBuffer.End;
|
||||
var consumed = buffer.End;
|
||||
try
|
||||
{
|
||||
if (readableBufferLength != 0)
|
||||
if (length != 0)
|
||||
{
|
||||
var actual = (int)Math.Min(readableBufferLength, buffer.Length);
|
||||
var actual = (int)Math.Min(length, destination.Length);
|
||||
|
||||
var slice = actual == readableBufferLength ? readableBuffer : readableBuffer.Slice(0, actual);
|
||||
var slice = actual == length ? buffer : buffer.Slice(0, actual);
|
||||
consumed = slice.End;
|
||||
slice.CopyTo(buffer.Span);
|
||||
slice.CopyTo(destination.Span);
|
||||
|
||||
return actual;
|
||||
}
|
||||
|
|
@ -193,37 +165,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
throw new ArgumentOutOfRangeException(nameof(bufferSize));
|
||||
}
|
||||
|
||||
return CopyToAsyncInternal(destination, cancellationToken);
|
||||
}
|
||||
|
||||
private async Task CopyToAsyncInternal(Stream destination, CancellationToken cancellationToken)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
var result = await _pipeReader.ReadAsync(cancellationToken);
|
||||
var readableBuffer = result.Buffer;
|
||||
var readableBufferLength = readableBuffer.Length;
|
||||
|
||||
try
|
||||
{
|
||||
if (readableBufferLength != 0)
|
||||
{
|
||||
foreach (var memory in readableBuffer)
|
||||
{
|
||||
await destination.WriteAsync(memory, cancellationToken);
|
||||
}
|
||||
}
|
||||
|
||||
if (result.IsCompleted)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
_pipeReader.AdvanceTo(readableBuffer.End);
|
||||
}
|
||||
}
|
||||
return _pipeReader.CopyToAsync(destination, cancellationToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Http.Features;
|
||||
|
|
@ -87,40 +86,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
|
||||
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
|
||||
{
|
||||
var task = WriteAsync(buffer, offset, count, default, state);
|
||||
if (callback != null)
|
||||
{
|
||||
task.ContinueWith(t => callback.Invoke(t));
|
||||
}
|
||||
return task;
|
||||
return TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state);
|
||||
}
|
||||
|
||||
public override void EndWrite(IAsyncResult asyncResult)
|
||||
{
|
||||
((Task<object>)asyncResult).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
private Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
|
||||
{
|
||||
var tcs = new TaskCompletionSource<object>(state);
|
||||
var task = WriteAsync(buffer, offset, count, cancellationToken);
|
||||
task.ContinueWith((task2, state2) =>
|
||||
{
|
||||
var tcs2 = (TaskCompletionSource<object>)state2;
|
||||
if (task2.IsCanceled)
|
||||
{
|
||||
tcs2.SetCanceled();
|
||||
}
|
||||
else if (task2.IsFaulted)
|
||||
{
|
||||
tcs2.SetException(task2.Exception);
|
||||
}
|
||||
else
|
||||
{
|
||||
tcs2.SetResult(null);
|
||||
}
|
||||
}, tcs, cancellationToken);
|
||||
return tcs.Task;
|
||||
TaskToApm.End(asyncResult);
|
||||
}
|
||||
|
||||
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
<Compile Include="$(SharedSourceRoot)CertificateGeneration\**\*.cs" />
|
||||
<Compile Include="$(SharedSourceRoot)ValueTaskExtensions\**\*.cs" />
|
||||
<Compile Include="$(SharedSourceRoot)UrlDecoder\**\*.cs" />
|
||||
<Compile Include="$(RepoRoot)src\Shared\TaskToApm.cs" Link="Internal\TaskToApm.cs" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
|
|
|||
|
|
@ -152,78 +152,22 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal
|
|||
|
||||
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
|
||||
{
|
||||
var task = ReadAsync(buffer, offset, count, default, state);
|
||||
if (callback != null)
|
||||
{
|
||||
task.ContinueWith(t => callback.Invoke(t));
|
||||
}
|
||||
return task;
|
||||
return TaskToApm.Begin(ReadAsync(buffer, offset, count), callback, state);
|
||||
}
|
||||
|
||||
public override int EndRead(IAsyncResult asyncResult)
|
||||
{
|
||||
return ((Task<int>)asyncResult).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
private Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
|
||||
{
|
||||
var tcs = new TaskCompletionSource<int>(state);
|
||||
var task = ReadAsync(buffer, offset, count, cancellationToken);
|
||||
task.ContinueWith((task2, state2) =>
|
||||
{
|
||||
var tcs2 = (TaskCompletionSource<int>)state2;
|
||||
if (task2.IsCanceled)
|
||||
{
|
||||
tcs2.SetCanceled();
|
||||
}
|
||||
else if (task2.IsFaulted)
|
||||
{
|
||||
tcs2.SetException(task2.Exception);
|
||||
}
|
||||
else
|
||||
{
|
||||
tcs2.SetResult(task2.Result);
|
||||
}
|
||||
}, tcs, cancellationToken);
|
||||
return tcs.Task;
|
||||
return TaskToApm.End<int>(asyncResult);
|
||||
}
|
||||
|
||||
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
|
||||
{
|
||||
var task = WriteAsync(buffer, offset, count, default, state);
|
||||
if (callback != null)
|
||||
{
|
||||
task.ContinueWith(t => callback.Invoke(t));
|
||||
}
|
||||
return task;
|
||||
return TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state);
|
||||
}
|
||||
|
||||
public override void EndWrite(IAsyncResult asyncResult)
|
||||
{
|
||||
((Task<object>)asyncResult).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
private Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
|
||||
{
|
||||
var tcs = new TaskCompletionSource<object>(state);
|
||||
var task = WriteAsync(buffer, offset, count, cancellationToken);
|
||||
task.ContinueWith((task2, state2) =>
|
||||
{
|
||||
var tcs2 = (TaskCompletionSource<object>)state2;
|
||||
if (task2.IsCanceled)
|
||||
{
|
||||
tcs2.SetCanceled();
|
||||
}
|
||||
else if (task2.IsFaulted)
|
||||
{
|
||||
tcs2.SetException(task2.Exception);
|
||||
}
|
||||
else
|
||||
{
|
||||
tcs2.SetResult(null);
|
||||
}
|
||||
}, tcs, cancellationToken);
|
||||
return tcs.Task;
|
||||
TaskToApm.End(asyncResult);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -176,78 +176,22 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal
|
|||
// The below APM methods call the underlying Read/WriteAsync methods which will still be logged.
|
||||
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
|
||||
{
|
||||
var task = ReadAsync(buffer, offset, count, default(CancellationToken), state);
|
||||
if (callback != null)
|
||||
{
|
||||
task.ContinueWith(t => callback.Invoke(t));
|
||||
}
|
||||
return task;
|
||||
return TaskToApm.Begin(ReadAsync(buffer, offset, count), callback, state);
|
||||
}
|
||||
|
||||
public override int EndRead(IAsyncResult asyncResult)
|
||||
{
|
||||
return ((Task<int>)asyncResult).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
private Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
|
||||
{
|
||||
var tcs = new TaskCompletionSource<int>(state);
|
||||
var task = ReadAsync(buffer, offset, count, cancellationToken);
|
||||
task.ContinueWith((task2, state2) =>
|
||||
{
|
||||
var tcs2 = (TaskCompletionSource<int>)state2;
|
||||
if (task2.IsCanceled)
|
||||
{
|
||||
tcs2.SetCanceled();
|
||||
}
|
||||
else if (task2.IsFaulted)
|
||||
{
|
||||
tcs2.SetException(task2.Exception);
|
||||
}
|
||||
else
|
||||
{
|
||||
tcs2.SetResult(task2.Result);
|
||||
}
|
||||
}, tcs, cancellationToken);
|
||||
return tcs.Task;
|
||||
return TaskToApm.End<int>(asyncResult);
|
||||
}
|
||||
|
||||
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
|
||||
{
|
||||
var task = WriteAsync(buffer, offset, count, default(CancellationToken), state);
|
||||
if (callback != null)
|
||||
{
|
||||
task.ContinueWith(t => callback.Invoke(t));
|
||||
}
|
||||
return task;
|
||||
return TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state);
|
||||
}
|
||||
|
||||
public override void EndWrite(IAsyncResult asyncResult)
|
||||
{
|
||||
((Task<object>)asyncResult).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
private Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
|
||||
{
|
||||
var tcs = new TaskCompletionSource<object>(state);
|
||||
var task = WriteAsync(buffer, offset, count, cancellationToken);
|
||||
task.ContinueWith((task2, state2) =>
|
||||
{
|
||||
var tcs2 = (TaskCompletionSource<object>)state2;
|
||||
if (task2.IsCanceled)
|
||||
{
|
||||
tcs2.SetCanceled();
|
||||
}
|
||||
else if (task2.IsFaulted)
|
||||
{
|
||||
tcs2.SetException(task2.Exception);
|
||||
}
|
||||
else
|
||||
{
|
||||
tcs2.SetResult(null);
|
||||
}
|
||||
}, tcs, cancellationToken);
|
||||
return tcs.Task;
|
||||
TaskToApm.End(asyncResult);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,121 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
// Helper methods for using Tasks to implement the APM pattern.
|
||||
//
|
||||
// Example usage, wrapping a Task<int>-returning FooAsync method with Begin/EndFoo methods:
|
||||
//
|
||||
// public IAsyncResult BeginFoo(..., AsyncCallback callback, object state) =>
|
||||
// TaskToApm.Begin(FooAsync(...), callback, state);
|
||||
//
|
||||
// public int EndFoo(IAsyncResult asyncResult) =>
|
||||
// TaskToApm.End<int>(asyncResult);
|
||||
|
||||
#nullable enable
|
||||
using System.Diagnostics;
|
||||
|
||||
namespace System.Threading.Tasks
|
||||
{
|
||||
/// <summary>
|
||||
/// Provides support for efficiently using Tasks to implement the APM (Begin/End) pattern.
|
||||
/// </summary>
|
||||
internal static class TaskToApm
|
||||
{
|
||||
/// <summary>
|
||||
/// Marshals the Task as an IAsyncResult, using the supplied callback and state
|
||||
/// to implement the APM pattern.
|
||||
/// </summary>
|
||||
/// <param name="task">The Task to be marshaled.</param>
|
||||
/// <param name="callback">The callback to be invoked upon completion.</param>
|
||||
/// <param name="state">The state to be stored in the IAsyncResult.</param>
|
||||
/// <returns>An IAsyncResult to represent the task's asynchronous operation.</returns>
|
||||
public static IAsyncResult Begin(Task task, AsyncCallback? callback, object? state) =>
|
||||
new TaskAsyncResult(task, state, callback);
|
||||
|
||||
/// <summary>Processes an IAsyncResult returned by Begin.</summary>
|
||||
/// <param name="asyncResult">The IAsyncResult to unwrap.</param>
|
||||
public static void End(IAsyncResult asyncResult)
|
||||
{
|
||||
if (asyncResult is TaskAsyncResult twar)
|
||||
{
|
||||
twar._task.GetAwaiter().GetResult();
|
||||
return;
|
||||
}
|
||||
|
||||
throw new ArgumentNullException();
|
||||
}
|
||||
|
||||
/// <summary>Processes an IAsyncResult returned by Begin.</summary>
|
||||
/// <param name="asyncResult">The IAsyncResult to unwrap.</param>
|
||||
public static TResult End<TResult>(IAsyncResult asyncResult)
|
||||
{
|
||||
if (asyncResult is TaskAsyncResult twar && twar._task is Task<TResult> task)
|
||||
{
|
||||
return task.GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
throw new ArgumentNullException();
|
||||
}
|
||||
|
||||
/// <summary>Provides a simple IAsyncResult that wraps a Task.</summary>
|
||||
/// <remarks>
|
||||
/// We could use the Task as the IAsyncResult if the Task's AsyncState is the same as the object state,
|
||||
/// but that's very rare, in particular in a situation where someone cares about allocation, and always
|
||||
/// using TaskAsyncResult simplifies things and enables additional optimizations.
|
||||
/// </remarks>
|
||||
internal sealed class TaskAsyncResult : IAsyncResult
|
||||
{
|
||||
/// <summary>The wrapped Task.</summary>
|
||||
internal readonly Task _task;
|
||||
/// <summary>Callback to invoke when the wrapped task completes.</summary>
|
||||
private readonly AsyncCallback? _callback;
|
||||
|
||||
/// <summary>Initializes the IAsyncResult with the Task to wrap and the associated object state.</summary>
|
||||
/// <param name="task">The Task to wrap.</param>
|
||||
/// <param name="state">The new AsyncState value.</param>
|
||||
/// <param name="callback">Callback to invoke when the wrapped task completes.</param>
|
||||
internal TaskAsyncResult(Task task, object? state, AsyncCallback? callback)
|
||||
{
|
||||
Debug.Assert(task != null);
|
||||
_task = task;
|
||||
AsyncState = state;
|
||||
|
||||
if (task.IsCompleted)
|
||||
{
|
||||
// Synchronous completion. Invoke the callback. No need to store it.
|
||||
CompletedSynchronously = true;
|
||||
callback?.Invoke(this);
|
||||
}
|
||||
else if (callback != null)
|
||||
{
|
||||
// Asynchronous completion, and we have a callback; schedule it. We use OnCompleted rather than ContinueWith in
|
||||
// order to avoid running synchronously if the task has already completed by the time we get here but still run
|
||||
// synchronously as part of the task's completion if the task completes after (the more common case).
|
||||
_callback = callback;
|
||||
_task.ConfigureAwait(continueOnCapturedContext: false)
|
||||
.GetAwaiter()
|
||||
.OnCompleted(InvokeCallback); // allocates a delegate, but avoids a closure
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>Invokes the callback.</summary>
|
||||
private void InvokeCallback()
|
||||
{
|
||||
Debug.Assert(!CompletedSynchronously);
|
||||
Debug.Assert(_callback != null);
|
||||
_callback.Invoke(this);
|
||||
}
|
||||
|
||||
/// <summary>Gets a user-defined object that qualifies or contains information about an asynchronous operation.</summary>
|
||||
public object? AsyncState { get; }
|
||||
/// <summary>Gets a value that indicates whether the asynchronous operation completed synchronously.</summary>
|
||||
/// <remarks>This is set lazily based on whether the <see cref="_task"/> has completed by the time this object is created.</remarks>
|
||||
public bool CompletedSynchronously { get; }
|
||||
/// <summary>Gets a value that indicates whether the asynchronous operation has completed.</summary>
|
||||
public bool IsCompleted => _task.IsCompleted;
|
||||
/// <summary>Gets a <see cref="WaitHandle"/> that is used to wait for an asynchronous operation to complete.</summary>
|
||||
public WaitHandle AsyncWaitHandle => ((IAsyncResult)_task).AsyncWaitHandle;
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue