diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs index 932a1bc665..4882a1259b 100644 --- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.IO.Pipelines; -using System.Linq; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.Internal; diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index 3db914b5d6..d97118d6a3 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -56,11 +56,40 @@ namespace Microsoft.AspNetCore.SignalR { // TODO: Dispatch from the caller await Task.Yield(); - Exception exception = null; + try { await _lifetimeManager.OnConnectedAsync(connection); + await RunHubAsync(connection); + } + finally + { + await _lifetimeManager.OnDisconnectedAsync(connection); + } + } + private async Task RunHubAsync(Connection connection) + { + await HubOnConnectedAsync(connection); + + try + { + await DispatchMessagesAsync(connection); + } + catch (Exception ex) + { + _logger.LogError(0, ex, "Error when processing requests."); + await HubOnDisconnectedAsync(connection, ex); + throw; + } + + await HubOnDisconnectedAsync(connection, null); + } + + private async Task HubOnConnectedAsync(Connection connection) + { + try + { using (var scope = _serviceScopeFactory.CreateScope()) { var hubActivator = scope.ServiceProvider.GetRequiredService>(); @@ -75,17 +104,17 @@ namespace Microsoft.AspNetCore.SignalR hubActivator.Release(hub); } } - - await DispatchMessagesAsync(connection); } catch (Exception ex) { - _logger.LogError(0, ex, "Error when processing requests."); - exception = ex; - connection.Channel.Input.Complete(exception); - connection.Channel.Output.Complete(exception); + _logger.LogError(0, ex, "Error when invoking OnConnectedAsync on hub."); + throw; } - finally + } + + private async Task HubOnDisconnectedAsync(Connection connection, Exception exception) + { + try { using (var scope = _serviceScopeFactory.CreateScope()) { @@ -101,8 +130,11 @@ namespace Microsoft.AspNetCore.SignalR hubActivator.Release(hub); } } - - await _lifetimeManager.OnDisconnectedAsync(connection); + } + catch (Exception ex) + { + _logger.LogError(0, ex, "Error when invoking OnDisconnectedAsync on hub."); + throw; } } diff --git a/src/Microsoft.AspNetCore.SignalR/Proxies.cs b/src/Microsoft.AspNetCore.SignalR/Proxies.cs index 9659d643c9..dfed4a80ff 100644 --- a/src/Microsoft.AspNetCore.SignalR/Proxies.cs +++ b/src/Microsoft.AspNetCore.SignalR/Proxies.cs @@ -1,8 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index a69bf835ab..421ee4cd68 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -111,9 +111,16 @@ namespace Microsoft.AspNetCore.Sockets // REVIEW: This is super gross, this all needs to be cleaned up... state.Close = async () => { - state.Connection.Channel.Dispose(); + try + { + await endpointTask; + } + catch + { + // possibly invoked on a ThreadPool thread + } - await endpointTask; + state.Connection.Channel.Dispose(); }; endpointTask = endpoint.OnConnectedAsync(state.Connection); @@ -130,6 +137,11 @@ namespace Microsoft.AspNetCore.Sockets if (resultTask == endpointTask) { // Notify the long polling transport to end + if (endpointTask.IsFaulted) + { + state.Connection.Channel.Input.Complete(endpointTask.Exception.InnerException); + state.Connection.Channel.Output.Complete(endpointTask.Exception.InnerException); + } state.Connection.Channel.Dispose(); await transportTask; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 9d47059246..27e1fedaa2 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -9,7 +9,6 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection; using Moq; -using Moq.Protected; using Xunit; namespace Microsoft.AspNetCore.SignalR.Tests @@ -64,18 +63,127 @@ namespace Microsoft.AspNetCore.SignalR.Tests connectionWrapper.Connection.Channel.Dispose(); - await endPointTask; + // InvalidCastException because the payload is not a JObject + // which is expected by the formatter + await Assert.ThrowsAsync(async () => await endPointTask); Mock.Get(hub).Verify(h => h.OnDisconnectedAsync(It.IsNotNull()), Times.Once()); } } + [Fact] + public async Task LifetimeManagerOnDisconnectedAsyncCalledIfLifetimeManagerOnConnectedAsyncThrows() + { + var mockLifetimeManager = new Mock>(); + mockLifetimeManager + .Setup(m => m.OnConnectedAsync(It.IsAny())) + .Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed.")); + var mockHubActivator = new Mock>(); + + var serviceProvider = CreateServiceProvider(services => + { + services.AddSingleton(mockLifetimeManager.Object); + services.AddSingleton(mockHubActivator.Object); + }); + + var endPoint = serviceProvider.GetService>(); + + using (var connectionWrapper = new ConnectionWrapper()) + { + var exception = + await Assert.ThrowsAsync( + async () => await endPoint.OnConnectedAsync(connectionWrapper.Connection)); + Assert.Equal("Lifetime manager OnConnectedAsync failed.", exception.Message); + + connectionWrapper.Connection.Channel.Dispose(); + + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); + // No hubs should be created since the connection is terminated + mockHubActivator.Verify(m => m.Create(), Times.Never); + mockHubActivator.Verify(m => m.Release(It.IsAny()), Times.Never); + } + } + + [Fact] + public async Task HubOnDisconnectedAsyncCalledIfHubOnConnectedAsyncThrows() + { + var mockLifetimeManager = new Mock>(); + var serviceProvider = CreateServiceProvider(services => + { + services.AddSingleton(mockLifetimeManager.Object); + }); + + var endPoint = serviceProvider.GetService>(); + + using (var connectionWrapper = new ConnectionWrapper()) + { + var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); + connectionWrapper.Connection.Channel.Dispose(); + + var exception = await Assert.ThrowsAsync(async () => await endPointTask); + Assert.Equal("Hub OnConnected failed.", exception.Message); + + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); + } + } + + [Fact] + public async Task LifetimeManagerOnDisconnectedAsyncCalledIfHubOnDisconnectedAsyncThrows() + { + var mockLifetimeManager = new Mock>(); + var serviceProvider = CreateServiceProvider(services => + { + services.AddSingleton(mockLifetimeManager.Object); + }); + + var endPoint = serviceProvider.GetService>(); + + using (var connectionWrapper = new ConnectionWrapper()) + { + var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); + connectionWrapper.Connection.Channel.Dispose(); + + var exception = await Assert.ThrowsAsync(async () => await endPointTask); + Assert.Equal("Hub OnDisconnected failed.", exception.Message); + + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); + } + } + private static Type GetEndPointType(Type hubType) { var endPointType = typeof(HubEndPoint<>); return endPointType.MakeGenericType(hubType); } + private static Type GetGenericType(Type genericType, Type hubType) + { + return genericType.MakeGenericType(hubType); + } + + public class OnConnectedThrowsHub : Hub + { + public override Task OnConnectedAsync() + { + var tcs = new TaskCompletionSource(); + tcs.SetException(new InvalidOperationException("Hub OnConnected failed.")); + return tcs.Task; + } + } + + public class OnDisconnectedThrowsHub : Hub + { + public override Task OnDisconnectedAsync(Exception exception) + { + var tcs = new TaskCompletionSource(); + tcs.SetException(new InvalidOperationException("Hub OnDisconnected failed.")); + return tcs.Task; + } + } + private class TestHub : Hub { private TrackDispose _trackDispose;