Use ForceAsyncAwaiter to make sure we're not on the app SyncContext (#758)
This commit is contained in:
parent
83f3605cfb
commit
a200cd46b1
|
|
@ -1,6 +1,6 @@
|
|||
Microsoft Visual Studio Solution File, Format Version 12.00
|
||||
# Visual Studio 15
|
||||
VisualStudioVersion = 15.0.26730.0
|
||||
VisualStudioVersion = 15.0.26823.1
|
||||
MinimumVisualStudioVersion = 10.0.40219.1
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}"
|
||||
EndProject
|
||||
|
|
|
|||
|
|
@ -0,0 +1,75 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets.Internal
|
||||
{
|
||||
public static class ForceAsyncTaskExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Returns an awaitable/awaiter that will ensure the continuation is executed
|
||||
/// asynchronously on the thread pool, even if the task is already completed
|
||||
/// by the time the await occurs. Effectively, it is equivalent to awaiting
|
||||
/// with ConfigureAwait(false) and then queuing the continuation with Task.Run,
|
||||
/// but it avoids the extra hop if the continuation already executed asynchronously.
|
||||
/// </summary>
|
||||
public static ForceAsyncAwaiter ForceAsync(this Task task)
|
||||
{
|
||||
return new ForceAsyncAwaiter(task);
|
||||
}
|
||||
|
||||
public static ForceAsyncAwaiter<T> ForceAsync<T>(this Task<T> task)
|
||||
{
|
||||
return new ForceAsyncAwaiter<T>(task);
|
||||
}
|
||||
}
|
||||
|
||||
public struct ForceAsyncAwaiter : ICriticalNotifyCompletion
|
||||
{
|
||||
private readonly Task _task;
|
||||
|
||||
internal ForceAsyncAwaiter(Task task) { _task = task; }
|
||||
|
||||
public ForceAsyncAwaiter GetAwaiter() { return this; }
|
||||
|
||||
public bool IsCompleted { get { return false; } } // the purpose of this type is to always force a continuation
|
||||
|
||||
public void GetResult() { _task.GetAwaiter().GetResult(); }
|
||||
|
||||
public void OnCompleted(Action action)
|
||||
{
|
||||
_task.ConfigureAwait(false).GetAwaiter().OnCompleted(action);
|
||||
}
|
||||
|
||||
public void UnsafeOnCompleted(Action action)
|
||||
{
|
||||
_task.ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(action);
|
||||
}
|
||||
}
|
||||
|
||||
public struct ForceAsyncAwaiter<T> : ICriticalNotifyCompletion
|
||||
{
|
||||
private readonly Task<T> _task;
|
||||
|
||||
internal ForceAsyncAwaiter(Task<T> task) { _task = task; }
|
||||
|
||||
public ForceAsyncAwaiter<T> GetAwaiter() { return this; }
|
||||
|
||||
public bool IsCompleted { get { return false; } } // the purpose of this type is to always force a continuation
|
||||
|
||||
public T GetResult() { return _task.GetAwaiter().GetResult(); }
|
||||
|
||||
public void OnCompleted(Action action)
|
||||
{
|
||||
_task.ConfigureAwait(false).GetAwaiter().OnCompleted(action);
|
||||
}
|
||||
|
||||
public void UnsafeOnCompleted(Action action)
|
||||
{
|
||||
_task.ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(action);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -6,7 +6,6 @@ using System.Collections.Concurrent;
|
|||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
|
|
@ -17,6 +16,7 @@ using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
|||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.AspNetCore.Sockets.Client;
|
||||
using Microsoft.AspNetCore.Sockets.Features;
|
||||
using Microsoft.AspNetCore.Sockets.Internal;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Newtonsoft.Json;
|
||||
|
|
@ -81,7 +81,9 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
_connection.Closed += Shutdown;
|
||||
}
|
||||
|
||||
public async Task StartAsync()
|
||||
public async Task StartAsync() => await StartAsyncCore().ForceAsync();
|
||||
|
||||
private async Task StartAsyncCore()
|
||||
{
|
||||
var transferModeFeature = _connection.Features.Get<ITransferModeFeature>();
|
||||
if (transferModeFeature == null)
|
||||
|
|
@ -122,7 +124,9 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
return new PassThroughEncoder();
|
||||
}
|
||||
|
||||
public async Task DisposeAsync()
|
||||
public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync();
|
||||
|
||||
private async Task DisposeAsyncCore()
|
||||
{
|
||||
await _connection.DisposeAsync();
|
||||
}
|
||||
|
|
@ -141,14 +145,20 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
return channel;
|
||||
}
|
||||
|
||||
public async Task<object> InvokeAsync(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args)
|
||||
public async Task<object> InvokeAsync(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args) =>
|
||||
await InvokeAsyncCore(methodName, returnType, cancellationToken, args).ForceAsync();
|
||||
|
||||
private async Task<object> InvokeAsyncCore(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args)
|
||||
{
|
||||
var irq = InvocationRequest.Invoke(cancellationToken, returnType, GetNextId(), _loggerFactory, out var task);
|
||||
await InvokeCore(methodName, irq, args, nonBlocking: false);
|
||||
return await task;
|
||||
}
|
||||
|
||||
public Task SendAsync(string methodName, CancellationToken cancellationToken, params object[] args)
|
||||
public async Task SendAsync(string methodName, CancellationToken cancellationToken, params object[] args) =>
|
||||
await SendAsyncCore(methodName, cancellationToken, args).ForceAsync();
|
||||
|
||||
private Task SendAsyncCore(string methodName, CancellationToken cancellationToken, params object[] args)
|
||||
{
|
||||
var irq = InvocationRequest.Invoke(cancellationToken, typeof(void), GetNextId(), _loggerFactory, out _);
|
||||
return InvokeCore(methodName, irq, args, nonBlocking: true);
|
||||
|
|
|
|||
|
|
@ -11,6 +11,10 @@
|
|||
<EnableApiCheck>false</EnableApiCheck>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Compile Include="..\Common\ForceAsyncAwaiter.cs" Link="ForceAsyncAwaiter.cs" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj" />
|
||||
|
|
|
|||
|
|
@ -88,7 +88,9 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
_transportFactory = transportFactory ?? throw new ArgumentNullException(nameof(transportFactory));
|
||||
}
|
||||
|
||||
public Task StartAsync()
|
||||
public async Task StartAsync() => await StartAsyncCore().ForceAsync();
|
||||
|
||||
private Task StartAsyncCore()
|
||||
{
|
||||
if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connecting, ConnectionState.Initial)
|
||||
!= ConnectionState.Initial)
|
||||
|
|
@ -357,7 +359,10 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
_logger.EndReceive(_connectionId);
|
||||
}
|
||||
|
||||
public async Task SendAsync(byte[] data, CancellationToken cancellationToken = default(CancellationToken))
|
||||
public async Task SendAsync(byte[] data, CancellationToken cancellationToken = default(CancellationToken)) =>
|
||||
await SendAsyncCore(data, cancellationToken).ForceAsync();
|
||||
|
||||
private async Task SendAsyncCore(byte[] data, CancellationToken cancellationToken)
|
||||
{
|
||||
if (data == null)
|
||||
{
|
||||
|
|
@ -389,7 +394,9 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
}
|
||||
}
|
||||
|
||||
public async Task DisposeAsync()
|
||||
public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync();
|
||||
|
||||
private async Task DisposeAsyncCore()
|
||||
{
|
||||
_logger.StoppingClient(_connectionId);
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,10 @@
|
|||
<EnableApiCheck>false</EnableApiCheck>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Compile Include="..\Common\ForceAsyncAwaiter.cs" Link="ForceAsyncAwaiter.cs" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common.Http\Microsoft.AspNetCore.Sockets.Common.Http.csproj" />
|
||||
|
|
|
|||
Loading…
Reference in New Issue