diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionHandler.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionHandler.cs index 9c7ff24e4f..668016c25e 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionHandler.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionHandler.cs @@ -164,9 +164,6 @@ namespace Microsoft.AspNetCore.SignalR private async Task DispatchMessagesAsync(HubConnectionContext connection) { - // Since we dispatch multiple hub invocations in parallel, we need a way to communicate failure back to the main processing loop. - // This is done by aborting the connection. - try { var input = connection.Input; @@ -182,9 +179,9 @@ namespace Microsoft.AspNetCore.SignalR { while (protocol.TryParseMessage(ref buffer, _dispatcher, out var message)) { - // Don't wait on the result of execution, continue processing other - // incoming messages on this connection. - _ = _dispatcher.DispatchMessageAsync(connection, message); + // Messages are dispatched sequentially and will block other messages from being processed until they complete. + // Streaming methods will run sequentially until they start streaming, then they will fire-and-forget allowing other messages to run. + await _dispatcher.DispatchMessageAsync(connection, message); } } else if (result.IsCompleted) diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs index 66046ebaa5..29bbbed579 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs @@ -170,7 +170,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal { var methodExecutor = descriptor.MethodExecutor; - using (var scope = _serviceScopeFactory.CreateScope()) + var disposeScope = true; + var scope = _serviceScopeFactory.CreateScope(); + IHubActivator hubActivator = null; + THub hub = null; + try { if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection.User, descriptor.Policies)) { @@ -185,8 +189,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal return; } - var hubActivator = scope.ServiceProvider.GetRequiredService>(); - var hub = hubActivator.Create(); + hubActivator = scope.ServiceProvider.GetRequiredService>(); + hub = hubActivator.Create(); try { @@ -205,8 +209,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal return; } + disposeScope = false; Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); - await StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, streamCts); + // Fire-and-forget stream invocations, otherwise they would block other hub invocations from being able to run + _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, streamCts); } // Non-empty/null InvocationId ==> Blocking invocation that needs a response else if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) @@ -227,51 +233,60 @@ namespace Microsoft.AspNetCore.SignalR.Internal await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors)); } - finally + } + finally + { + if (disposeScope) { - hubActivator.Release(hub); + hubActivator?.Release(hub); + scope.Dispose(); } } } - private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator enumerator, CancellationTokenSource streamCts) + private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator enumerator, IServiceScope scope, IHubActivator hubActivator, THub hub, CancellationTokenSource streamCts) { string error = null; - try + using (scope) { - while (await enumerator.MoveNextAsync()) + try { - // Send the stream item - await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current)); + while (await enumerator.MoveNextAsync()) + { + // Send the stream item + await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current)); + } } - } - catch (ChannelClosedException ex) - { - // If the channel closes from an exception in the streaming method, grab the innerException for the error from the streaming method - error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex.InnerException ?? ex, _enableDetailedErrors); - } - catch (Exception ex) - { - // If the streaming method was canceled we don't want to send a HubException message - this is not an error case - if (!(ex is OperationCanceledException && connection.ActiveRequestCancellationSources.TryGetValue(invocationId, out var cts) - && cts.IsCancellationRequested)) + catch (ChannelClosedException ex) { - error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex, _enableDetailedErrors); + // If the channel closes from an exception in the streaming method, grab the innerException for the error from the streaming method + error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex.InnerException ?? ex, _enableDetailedErrors); } - } - finally - { - (enumerator as IDisposable)?.Dispose(); - - // Dispose the linked CTS for the stream. - streamCts.Dispose(); - - await connection.WriteAsync(CompletionMessage.WithError(invocationId, error)); - - if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts)) + catch (Exception ex) { - cts.Dispose(); + // If the streaming method was canceled we don't want to send a HubException message - this is not an error case + if (!(ex is OperationCanceledException && connection.ActiveRequestCancellationSources.TryGetValue(invocationId, out var cts) + && cts.IsCancellationRequested)) + { + error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex, _enableDetailedErrors); + } + } + finally + { + (enumerator as IDisposable)?.Dispose(); + + hubActivator.Release(hub); + + // Dispose the linked CTS for the stream. + streamCts.Dispose(); + + await connection.WriteAsync(CompletionMessage.WithError(invocationId, error)); + + if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts)) + { + cts.Dispose(); + } } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs index 0de1b849bf..514d069757 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs @@ -3,8 +3,6 @@ using System; using System.Collections.Generic; -using System.Reactive.Linq; -using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization; @@ -431,7 +429,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests public class StreamingHub : TestHub { - public ChannelReader CounterChannel(int count) { var channel = Channel.CreateUnbounded(); @@ -471,6 +468,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests channel.Writer.TryComplete(new Exception("Exception from channel")); return channel.Reader; } + + public int NonStream() + { + return 42; + } } public class SimpleHub : Hub @@ -491,6 +493,42 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + public class LongRunningHub : Hub + { + private TcsService _tcsService; + + public LongRunningHub(TcsService tcsService) + { + _tcsService = tcsService; + } + + public async Task LongRunningMethod() + { + _tcsService.StartedMethod.TrySetResult(null); + await _tcsService.EndMethod.Task; + return 12; + } + + public async Task> LongRunningStream() + { + _tcsService.StartedMethod.TrySetResult(null); + await _tcsService.EndMethod.Task; + // Never ending stream + return Channel.CreateUnbounded().Reader; + } + + public int SimpleMethod() + { + return 21; + } + } + + public class TcsService + { + public TaskCompletionSource StartedMethod = new TaskCompletionSource(); + public TaskCompletionSource EndMethod = new TaskCompletionSource(); + } + public interface ITypedHubClient { Task Send(string message); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs index 3ea5d2e4f7..900df9778a 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs @@ -13,6 +13,7 @@ using MessagePack.Formatters; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Connections.Features; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Options; @@ -1982,6 +1983,153 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task StreamingInvocationsDoNotBlockOtherInvocations() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient(new JsonHubProtocol())) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + // Blocking streaming invocation to test that other invocations can still run + await client.SendHubMessageAsync(new StreamInvocationMessage("1", nameof(StreamingHub.BlockingStream), Array.Empty())).OrTimeout(); + + var completion = await client.InvokeAsync(nameof(StreamingHub.NonStream)).OrTimeout(); + Assert.Equal(42L, completion.Result); + + // Shut down + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + + [Fact] + public async Task InvocationsRunInOrder() + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(tcsService); + }); + var connectionHandler = serviceProvider.GetService>(); + + // Because we use PipeScheduler.Inline the hub invocations will run inline until they wait, which happens inside the LongRunningMethod call + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + // Long running hub invocation to test that other invocations will not run until it is completed + await client.SendInvocationAsync(nameof(LongRunningHub.LongRunningMethod), nonBlocking: false).OrTimeout(); + // Wait for the long running method to start + await tcsService.StartedMethod.Task.OrTimeout(); + + // Invoke another hub method which will wait for the first method to finish + await client.SendInvocationAsync(nameof(LongRunningHub.SimpleMethod), nonBlocking: false).OrTimeout(); + // Both invocations should be waiting now + Assert.Null(client.TryRead()); + + // Release the long running hub method + tcsService.EndMethod.TrySetResult(null); + + // Long running hub method result + var firstResult = await client.ReadAsync().OrTimeout(); + + var longRunningCompletion = Assert.IsType(firstResult); + Assert.Equal(12L, longRunningCompletion.Result); + + // simple hub method result + var secondResult = await client.ReadAsync().OrTimeout(); + + var simpleCompletion = Assert.IsType(secondResult); + Assert.Equal(21L, simpleCompletion.Result); + + // Shut down + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + + [Fact] + public async Task StreamInvocationsBlockOtherInvocationsUntilTheyStartStreaming() + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(tcsService); + builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>)); + }); + var connectionHandler = serviceProvider.GetService>(); + + // Because we use PipeScheduler.Inline the hub invocations will run inline until they wait, which happens inside the LongRunningMethod call + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + // Long running hub invocation to test that other invocations will not run until it is completed + var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.LongRunningStream)).OrTimeout(); + // Wait for the long running method to start + await tcsService.StartedMethod.Task.OrTimeout(); + + // Invoke another hub method which will wait for the first method to finish + await client.SendInvocationAsync(nameof(LongRunningHub.SimpleMethod), nonBlocking: false).OrTimeout(); + // Both invocations should be waiting now + Assert.Null(client.TryRead()); + + // Release the long running hub method + tcsService.EndMethod.TrySetResult(null); + + // simple hub method result + var result = await client.ReadAsync().OrTimeout(); + + var simpleCompletion = Assert.IsType(result); + Assert.Equal(21L, simpleCompletion.Result); + + var hubActivator = serviceProvider.GetService>() as CustomHubActivator; + + // OnConnectedAsync and SimpleMethod hubs have been disposed at this point + Assert.Equal(2, hubActivator.ReleaseCount); + + await client.SendHubMessageAsync(new CancelInvocationMessage(streamInvocationId)).OrTimeout(); + + // Completion message for canceled Stream + await client.ReadAsync().OrTimeout(); + + // Stream method is now disposed + Assert.Equal(3, hubActivator.ReleaseCount); + + // Shut down + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + + private class CustomHubActivator : IHubActivator where THub : Hub + { + public int ReleaseCount; + private IServiceProvider _serviceProvider; + + public CustomHubActivator(IServiceProvider serviceProvider) + { + _serviceProvider = serviceProvider; + } + + public THub Create() + { + return new DefaultHubActivator(_serviceProvider).Create(); + } + + public void Release(THub hub) + { + ReleaseCount++; + hub.Dispose(); + } + } + public static IEnumerable HubTypes() { yield return new[] { typeof(DynamicTestHub) };