Use ForceAsyncAwaiter to make sure we're not on the app SyncContext (#758)

This commit is contained in:
Mikael Mengistu 2017-08-25 14:02:17 -07:00 committed by GitHub
parent 83f3605cfb
commit a200cd46b1
6 changed files with 109 additions and 9 deletions

View File

@ -1,6 +1,6 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
VisualStudioVersion = 15.0.26730.0
VisualStudioVersion = 15.0.26823.1
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}"
EndProject

View File

@ -0,0 +1,75 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public static class ForceAsyncTaskExtensions
{
/// <summary>
/// Returns an awaitable/awaiter that will ensure the continuation is executed
/// asynchronously on the thread pool, even if the task is already completed
/// by the time the await occurs. Effectively, it is equivalent to awaiting
/// with ConfigureAwait(false) and then queuing the continuation with Task.Run,
/// but it avoids the extra hop if the continuation already executed asynchronously.
/// </summary>
public static ForceAsyncAwaiter ForceAsync(this Task task)
{
return new ForceAsyncAwaiter(task);
}
public static ForceAsyncAwaiter<T> ForceAsync<T>(this Task<T> task)
{
return new ForceAsyncAwaiter<T>(task);
}
}
public struct ForceAsyncAwaiter : ICriticalNotifyCompletion
{
private readonly Task _task;
internal ForceAsyncAwaiter(Task task) { _task = task; }
public ForceAsyncAwaiter GetAwaiter() { return this; }
public bool IsCompleted { get { return false; } } // the purpose of this type is to always force a continuation
public void GetResult() { _task.GetAwaiter().GetResult(); }
public void OnCompleted(Action action)
{
_task.ConfigureAwait(false).GetAwaiter().OnCompleted(action);
}
public void UnsafeOnCompleted(Action action)
{
_task.ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(action);
}
}
public struct ForceAsyncAwaiter<T> : ICriticalNotifyCompletion
{
private readonly Task<T> _task;
internal ForceAsyncAwaiter(Task<T> task) { _task = task; }
public ForceAsyncAwaiter<T> GetAwaiter() { return this; }
public bool IsCompleted { get { return false; } } // the purpose of this type is to always force a continuation
public T GetResult() { return _task.GetAwaiter().GetResult(); }
public void OnCompleted(Action action)
{
_task.ConfigureAwait(false).GetAwaiter().OnCompleted(action);
}
public void UnsafeOnCompleted(Action action)
{
_task.ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(action);
}
}
}

View File

@ -6,7 +6,6 @@ using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
@ -17,6 +16,7 @@ using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Features;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Newtonsoft.Json;
@ -81,7 +81,9 @@ namespace Microsoft.AspNetCore.SignalR.Client
_connection.Closed += Shutdown;
}
public async Task StartAsync()
public async Task StartAsync() => await StartAsyncCore().ForceAsync();
private async Task StartAsyncCore()
{
var transferModeFeature = _connection.Features.Get<ITransferModeFeature>();
if (transferModeFeature == null)
@ -122,7 +124,9 @@ namespace Microsoft.AspNetCore.SignalR.Client
return new PassThroughEncoder();
}
public async Task DisposeAsync()
public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync();
private async Task DisposeAsyncCore()
{
await _connection.DisposeAsync();
}
@ -141,14 +145,20 @@ namespace Microsoft.AspNetCore.SignalR.Client
return channel;
}
public async Task<object> InvokeAsync(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args)
public async Task<object> InvokeAsync(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args) =>
await InvokeAsyncCore(methodName, returnType, cancellationToken, args).ForceAsync();
private async Task<object> InvokeAsyncCore(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args)
{
var irq = InvocationRequest.Invoke(cancellationToken, returnType, GetNextId(), _loggerFactory, out var task);
await InvokeCore(methodName, irq, args, nonBlocking: false);
return await task;
}
public Task SendAsync(string methodName, CancellationToken cancellationToken, params object[] args)
public async Task SendAsync(string methodName, CancellationToken cancellationToken, params object[] args) =>
await SendAsyncCore(methodName, cancellationToken, args).ForceAsync();
private Task SendAsyncCore(string methodName, CancellationToken cancellationToken, params object[] args)
{
var irq = InvocationRequest.Invoke(cancellationToken, typeof(void), GetNextId(), _loggerFactory, out _);
return InvokeCore(methodName, irq, args, nonBlocking: true);

View File

@ -11,6 +11,10 @@
<EnableApiCheck>false</EnableApiCheck>
</PropertyGroup>
<ItemGroup>
<Compile Include="..\Common\ForceAsyncAwaiter.cs" Link="ForceAsyncAwaiter.cs" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj" />

View File

@ -88,7 +88,9 @@ namespace Microsoft.AspNetCore.Sockets.Client
_transportFactory = transportFactory ?? throw new ArgumentNullException(nameof(transportFactory));
}
public Task StartAsync()
public async Task StartAsync() => await StartAsyncCore().ForceAsync();
private Task StartAsyncCore()
{
if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connecting, ConnectionState.Initial)
!= ConnectionState.Initial)
@ -357,7 +359,10 @@ namespace Microsoft.AspNetCore.Sockets.Client
_logger.EndReceive(_connectionId);
}
public async Task SendAsync(byte[] data, CancellationToken cancellationToken = default(CancellationToken))
public async Task SendAsync(byte[] data, CancellationToken cancellationToken = default(CancellationToken)) =>
await SendAsyncCore(data, cancellationToken).ForceAsync();
private async Task SendAsyncCore(byte[] data, CancellationToken cancellationToken)
{
if (data == null)
{
@ -389,7 +394,9 @@ namespace Microsoft.AspNetCore.Sockets.Client
}
}
public async Task DisposeAsync()
public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync();
private async Task DisposeAsyncCore()
{
_logger.StoppingClient(_connectionId);

View File

@ -11,6 +11,10 @@
<EnableApiCheck>false</EnableApiCheck>
</PropertyGroup>
<ItemGroup>
<Compile Include="..\Common\ForceAsyncAwaiter.cs" Link="ForceAsyncAwaiter.cs" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common.Http\Microsoft.AspNetCore.Sockets.Common.Http.csproj" />