diff --git a/SignalR.sln b/SignalR.sln index 6f89370fc0..19a8a12b1e 100644 --- a/SignalR.sln +++ b/SignalR.sln @@ -1,6 +1,6 @@ Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 15 -VisualStudioVersion = 15.0.26730.0 +VisualStudioVersion = 15.0.26823.1 MinimumVisualStudioVersion = 10.0.40219.1 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}" EndProject diff --git a/src/Common/ForceAsyncAwaiter.cs b/src/Common/ForceAsyncAwaiter.cs new file mode 100644 index 0000000000..b52aa8a9a6 --- /dev/null +++ b/src/Common/ForceAsyncAwaiter.cs @@ -0,0 +1,75 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets.Internal +{ + public static class ForceAsyncTaskExtensions + { + /// + /// Returns an awaitable/awaiter that will ensure the continuation is executed + /// asynchronously on the thread pool, even if the task is already completed + /// by the time the await occurs. Effectively, it is equivalent to awaiting + /// with ConfigureAwait(false) and then queuing the continuation with Task.Run, + /// but it avoids the extra hop if the continuation already executed asynchronously. + /// + public static ForceAsyncAwaiter ForceAsync(this Task task) + { + return new ForceAsyncAwaiter(task); + } + + public static ForceAsyncAwaiter ForceAsync(this Task task) + { + return new ForceAsyncAwaiter(task); + } + } + + public struct ForceAsyncAwaiter : ICriticalNotifyCompletion + { + private readonly Task _task; + + internal ForceAsyncAwaiter(Task task) { _task = task; } + + public ForceAsyncAwaiter GetAwaiter() { return this; } + + public bool IsCompleted { get { return false; } } // the purpose of this type is to always force a continuation + + public void GetResult() { _task.GetAwaiter().GetResult(); } + + public void OnCompleted(Action action) + { + _task.ConfigureAwait(false).GetAwaiter().OnCompleted(action); + } + + public void UnsafeOnCompleted(Action action) + { + _task.ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(action); + } + } + + public struct ForceAsyncAwaiter : ICriticalNotifyCompletion + { + private readonly Task _task; + + internal ForceAsyncAwaiter(Task task) { _task = task; } + + public ForceAsyncAwaiter GetAwaiter() { return this; } + + public bool IsCompleted { get { return false; } } // the purpose of this type is to always force a continuation + + public T GetResult() { return _task.GetAwaiter().GetResult(); } + + public void OnCompleted(Action action) + { + _task.ConfigureAwait(false).GetAwaiter().OnCompleted(action); + } + + public void UnsafeOnCompleted(Action action) + { + _task.ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(action); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs index bc87ffedf2..15fe5dafbd 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs @@ -6,7 +6,6 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.IO; -using System.Linq; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; @@ -17,6 +16,7 @@ using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Features; +using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Newtonsoft.Json; @@ -81,7 +81,9 @@ namespace Microsoft.AspNetCore.SignalR.Client _connection.Closed += Shutdown; } - public async Task StartAsync() + public async Task StartAsync() => await StartAsyncCore().ForceAsync(); + + private async Task StartAsyncCore() { var transferModeFeature = _connection.Features.Get(); if (transferModeFeature == null) @@ -122,7 +124,9 @@ namespace Microsoft.AspNetCore.SignalR.Client return new PassThroughEncoder(); } - public async Task DisposeAsync() + public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync(); + + private async Task DisposeAsyncCore() { await _connection.DisposeAsync(); } @@ -141,14 +145,20 @@ namespace Microsoft.AspNetCore.SignalR.Client return channel; } - public async Task InvokeAsync(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args) + public async Task InvokeAsync(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args) => + await InvokeAsyncCore(methodName, returnType, cancellationToken, args).ForceAsync(); + + private async Task InvokeAsyncCore(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args) { var irq = InvocationRequest.Invoke(cancellationToken, returnType, GetNextId(), _loggerFactory, out var task); await InvokeCore(methodName, irq, args, nonBlocking: false); return await task; } - public Task SendAsync(string methodName, CancellationToken cancellationToken, params object[] args) + public async Task SendAsync(string methodName, CancellationToken cancellationToken, params object[] args) => + await SendAsyncCore(methodName, cancellationToken, args).ForceAsync(); + + private Task SendAsyncCore(string methodName, CancellationToken cancellationToken, params object[] args) { var irq = InvocationRequest.Invoke(cancellationToken, typeof(void), GetNextId(), _loggerFactory, out _); return InvokeCore(methodName, irq, args, nonBlocking: true); diff --git a/src/Microsoft.AspNetCore.SignalR.Client/Microsoft.AspNetCore.SignalR.Client.csproj b/src/Microsoft.AspNetCore.SignalR.Client/Microsoft.AspNetCore.SignalR.Client.csproj index b2bd6dbd6e..df524170a9 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/Microsoft.AspNetCore.SignalR.Client.csproj +++ b/src/Microsoft.AspNetCore.SignalR.Client/Microsoft.AspNetCore.SignalR.Client.csproj @@ -11,6 +11,10 @@ false + + + + diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index 30f9041a83..29faf3d5f9 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -88,7 +88,9 @@ namespace Microsoft.AspNetCore.Sockets.Client _transportFactory = transportFactory ?? throw new ArgumentNullException(nameof(transportFactory)); } - public Task StartAsync() + public async Task StartAsync() => await StartAsyncCore().ForceAsync(); + + private Task StartAsyncCore() { if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connecting, ConnectionState.Initial) != ConnectionState.Initial) @@ -357,7 +359,10 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger.EndReceive(_connectionId); } - public async Task SendAsync(byte[] data, CancellationToken cancellationToken = default(CancellationToken)) + public async Task SendAsync(byte[] data, CancellationToken cancellationToken = default(CancellationToken)) => + await SendAsyncCore(data, cancellationToken).ForceAsync(); + + private async Task SendAsyncCore(byte[] data, CancellationToken cancellationToken) { if (data == null) { @@ -389,7 +394,9 @@ namespace Microsoft.AspNetCore.Sockets.Client } } - public async Task DisposeAsync() + public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync(); + + private async Task DisposeAsyncCore() { _logger.StoppingClient(_connectionId); diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Microsoft.AspNetCore.Sockets.Client.Http.csproj b/src/Microsoft.AspNetCore.Sockets.Client.Http/Microsoft.AspNetCore.Sockets.Client.Http.csproj index a8af7cb7d0..8e96477233 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Microsoft.AspNetCore.Sockets.Client.Http.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Microsoft.AspNetCore.Sockets.Client.Http.csproj @@ -11,6 +11,10 @@ false + + + +