Sorting out exceptions

Making sure that OnConnected/OnDisconnected events are invoked correctly (e.g. if invoking OnDisconnectedAsync on hub threw we would not call OnDisconnectedAsync on lifetime manager and therefore we would continue to use/track connections that were already closed)
This commit is contained in:
moozzyk 2016-12-30 15:28:49 -08:00
parent 80c5f9be0e
commit 217f707456
5 changed files with 166 additions and 17 deletions

View File

@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.Internal;

View File

@ -56,11 +56,40 @@ namespace Microsoft.AspNetCore.SignalR
{
// TODO: Dispatch from the caller
await Task.Yield();
Exception exception = null;
try
{
await _lifetimeManager.OnConnectedAsync(connection);
await RunHubAsync(connection);
}
finally
{
await _lifetimeManager.OnDisconnectedAsync(connection);
}
}
private async Task RunHubAsync(Connection connection)
{
await HubOnConnectedAsync(connection);
try
{
await DispatchMessagesAsync(connection);
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Error when processing requests.");
await HubOnDisconnectedAsync(connection, ex);
throw;
}
await HubOnDisconnectedAsync(connection, null);
}
private async Task HubOnConnectedAsync(Connection connection)
{
try
{
using (var scope = _serviceScopeFactory.CreateScope())
{
var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub, TClient>>();
@ -75,17 +104,17 @@ namespace Microsoft.AspNetCore.SignalR
hubActivator.Release(hub);
}
}
await DispatchMessagesAsync(connection);
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Error when processing requests.");
exception = ex;
connection.Channel.Input.Complete(exception);
connection.Channel.Output.Complete(exception);
_logger.LogError(0, ex, "Error when invoking OnConnectedAsync on hub.");
throw;
}
finally
}
private async Task HubOnDisconnectedAsync(Connection connection, Exception exception)
{
try
{
using (var scope = _serviceScopeFactory.CreateScope())
{
@ -101,8 +130,11 @@ namespace Microsoft.AspNetCore.SignalR
hubActivator.Release(hub);
}
}
await _lifetimeManager.OnDisconnectedAsync(connection);
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Error when invoking OnDisconnectedAsync on hub.");
throw;
}
}

View File

@ -1,8 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;

View File

@ -111,9 +111,16 @@ namespace Microsoft.AspNetCore.Sockets
// REVIEW: This is super gross, this all needs to be cleaned up...
state.Close = async () =>
{
state.Connection.Channel.Dispose();
try
{
await endpointTask;
}
catch
{
// possibly invoked on a ThreadPool thread
}
await endpointTask;
state.Connection.Channel.Dispose();
};
endpointTask = endpoint.OnConnectedAsync(state.Connection);
@ -130,6 +137,11 @@ namespace Microsoft.AspNetCore.Sockets
if (resultTask == endpointTask)
{
// Notify the long polling transport to end
if (endpointTask.IsFaulted)
{
state.Connection.Channel.Input.Complete(endpointTask.Exception.InnerException);
state.Connection.Channel.Output.Complete(endpointTask.Exception.InnerException);
}
state.Connection.Channel.Dispose();
await transportTask;

View File

@ -9,7 +9,6 @@ using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Moq;
using Moq.Protected;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Tests
@ -64,18 +63,127 @@ namespace Microsoft.AspNetCore.SignalR.Tests
connectionWrapper.Connection.Channel.Dispose();
await endPointTask;
// InvalidCastException because the payload is not a JObject
// which is expected by the formatter
await Assert.ThrowsAsync<InvalidCastException>(async () => await endPointTask);
Mock.Get(hub).Verify(h => h.OnDisconnectedAsync(It.IsNotNull<Exception>()), Times.Once());
}
}
[Fact]
public async Task LifetimeManagerOnDisconnectedAsyncCalledIfLifetimeManagerOnConnectedAsyncThrows()
{
var mockLifetimeManager = new Mock<HubLifetimeManager<Hub>>();
mockLifetimeManager
.Setup(m => m.OnConnectedAsync(It.IsAny<Connection>()))
.Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed."));
var mockHubActivator = new Mock<IHubActivator<Hub, IClientProxy>>();
var serviceProvider = CreateServiceProvider(services =>
{
services.AddSingleton(mockLifetimeManager.Object);
services.AddSingleton(mockHubActivator.Object);
});
var endPoint = serviceProvider.GetService<HubEndPoint<Hub>>();
using (var connectionWrapper = new ConnectionWrapper())
{
var exception =
await Assert.ThrowsAsync<InvalidOperationException>(
async () => await endPoint.OnConnectedAsync(connectionWrapper.Connection));
Assert.Equal("Lifetime manager OnConnectedAsync failed.", exception.Message);
connectionWrapper.Connection.Channel.Dispose();
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
// No hubs should be created since the connection is terminated
mockHubActivator.Verify(m => m.Create(), Times.Never);
mockHubActivator.Verify(m => m.Release(It.IsAny<Hub>()), Times.Never);
}
}
[Fact]
public async Task HubOnDisconnectedAsyncCalledIfHubOnConnectedAsyncThrows()
{
var mockLifetimeManager = new Mock<HubLifetimeManager<OnConnectedThrowsHub>>();
var serviceProvider = CreateServiceProvider(services =>
{
services.AddSingleton(mockLifetimeManager.Object);
});
var endPoint = serviceProvider.GetService<HubEndPoint<OnConnectedThrowsHub>>();
using (var connectionWrapper = new ConnectionWrapper())
{
var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection);
connectionWrapper.Connection.Channel.Dispose();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
Assert.Equal("Hub OnConnected failed.", exception.Message);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
}
}
[Fact]
public async Task LifetimeManagerOnDisconnectedAsyncCalledIfHubOnDisconnectedAsyncThrows()
{
var mockLifetimeManager = new Mock<HubLifetimeManager<OnDisconnectedThrowsHub>>();
var serviceProvider = CreateServiceProvider(services =>
{
services.AddSingleton(mockLifetimeManager.Object);
});
var endPoint = serviceProvider.GetService<HubEndPoint<OnDisconnectedThrowsHub>>();
using (var connectionWrapper = new ConnectionWrapper())
{
var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection);
connectionWrapper.Connection.Channel.Dispose();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
Assert.Equal("Hub OnDisconnected failed.", exception.Message);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
}
}
private static Type GetEndPointType(Type hubType)
{
var endPointType = typeof(HubEndPoint<>);
return endPointType.MakeGenericType(hubType);
}
private static Type GetGenericType(Type genericType, Type hubType)
{
return genericType.MakeGenericType(hubType);
}
public class OnConnectedThrowsHub : Hub
{
public override Task OnConnectedAsync()
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(new InvalidOperationException("Hub OnConnected failed."));
return tcs.Task;
}
}
public class OnDisconnectedThrowsHub : Hub
{
public override Task OnDisconnectedAsync(Exception exception)
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(new InvalidOperationException("Hub OnDisconnected failed."));
return tcs.Task;
}
}
private class TestHub : Hub
{
private TrackDispose _trackDispose;