diff --git a/SignalR.sln b/SignalR.sln
index 6f89370fc0..19a8a12b1e 100644
--- a/SignalR.sln
+++ b/SignalR.sln
@@ -1,6 +1,6 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
-VisualStudioVersion = 15.0.26730.0
+VisualStudioVersion = 15.0.26823.1
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}"
EndProject
diff --git a/src/Common/ForceAsyncAwaiter.cs b/src/Common/ForceAsyncAwaiter.cs
new file mode 100644
index 0000000000..b52aa8a9a6
--- /dev/null
+++ b/src/Common/ForceAsyncAwaiter.cs
@@ -0,0 +1,75 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Runtime.CompilerServices;
+using System.Threading.Tasks;
+
+namespace Microsoft.AspNetCore.Sockets.Internal
+{
+ public static class ForceAsyncTaskExtensions
+ {
+ ///
+ /// Returns an awaitable/awaiter that will ensure the continuation is executed
+ /// asynchronously on the thread pool, even if the task is already completed
+ /// by the time the await occurs. Effectively, it is equivalent to awaiting
+ /// with ConfigureAwait(false) and then queuing the continuation with Task.Run,
+ /// but it avoids the extra hop if the continuation already executed asynchronously.
+ ///
+ public static ForceAsyncAwaiter ForceAsync(this Task task)
+ {
+ return new ForceAsyncAwaiter(task);
+ }
+
+ public static ForceAsyncAwaiter ForceAsync(this Task task)
+ {
+ return new ForceAsyncAwaiter(task);
+ }
+ }
+
+ public struct ForceAsyncAwaiter : ICriticalNotifyCompletion
+ {
+ private readonly Task _task;
+
+ internal ForceAsyncAwaiter(Task task) { _task = task; }
+
+ public ForceAsyncAwaiter GetAwaiter() { return this; }
+
+ public bool IsCompleted { get { return false; } } // the purpose of this type is to always force a continuation
+
+ public void GetResult() { _task.GetAwaiter().GetResult(); }
+
+ public void OnCompleted(Action action)
+ {
+ _task.ConfigureAwait(false).GetAwaiter().OnCompleted(action);
+ }
+
+ public void UnsafeOnCompleted(Action action)
+ {
+ _task.ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(action);
+ }
+ }
+
+ public struct ForceAsyncAwaiter : ICriticalNotifyCompletion
+ {
+ private readonly Task _task;
+
+ internal ForceAsyncAwaiter(Task task) { _task = task; }
+
+ public ForceAsyncAwaiter GetAwaiter() { return this; }
+
+ public bool IsCompleted { get { return false; } } // the purpose of this type is to always force a continuation
+
+ public T GetResult() { return _task.GetAwaiter().GetResult(); }
+
+ public void OnCompleted(Action action)
+ {
+ _task.ConfigureAwait(false).GetAwaiter().OnCompleted(action);
+ }
+
+ public void UnsafeOnCompleted(Action action)
+ {
+ _task.ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(action);
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs
index bc87ffedf2..15fe5dafbd 100644
--- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs
+++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs
@@ -6,7 +6,6 @@ using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
-using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
@@ -17,6 +16,7 @@ using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Features;
+using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Newtonsoft.Json;
@@ -81,7 +81,9 @@ namespace Microsoft.AspNetCore.SignalR.Client
_connection.Closed += Shutdown;
}
- public async Task StartAsync()
+ public async Task StartAsync() => await StartAsyncCore().ForceAsync();
+
+ private async Task StartAsyncCore()
{
var transferModeFeature = _connection.Features.Get();
if (transferModeFeature == null)
@@ -122,7 +124,9 @@ namespace Microsoft.AspNetCore.SignalR.Client
return new PassThroughEncoder();
}
- public async Task DisposeAsync()
+ public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync();
+
+ private async Task DisposeAsyncCore()
{
await _connection.DisposeAsync();
}
@@ -141,14 +145,20 @@ namespace Microsoft.AspNetCore.SignalR.Client
return channel;
}
- public async Task