// aspnetcore/test/Microsoft.AspNetCore.Signal.../HubEndpointTests.cs
//
// 253 lines
// 9.5 KiB
// C#

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Security.Claims;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.DependencyInjection;
using Moq;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Tests
{
public class HubEndpointTests
{
[Fact]
public async Task HubsAreDisposed()
{
    // The tracker is registered as a singleton instance; TestHub receives a
    // TrackDispose through its constructor and bumps the counter when disposed.
    var disposeTracker = new TrackDispose();
    var provider = CreateServiceProvider(services =>
    {
        services.AddSingleton(disposeTracker);
    });
    var hubEndPoint = provider.GetService<HubEndPoint<TestHub>>();

    using (var wrapper = new ConnectionWrapper())
    {
        var runningEndPoint = hubEndPoint.OnConnectedAsync(wrapper.Connection);

        // Wait for the ReadingStarted signal of the transport input.
        await wrapper.ApplicationStartedReading;

        // Kill the connection, then wait for the end point task.
        wrapper.ConnectionState.Dispose();
        await runningEndPoint;

        // The test expects exactly two hub disposals for this single connection.
        Assert.Equal(2, disposeTracker.DisposeCount);
    }
}
[Fact]
public async Task OnDisconnectedCalledWithExceptionIfHubMethodNotFound()
{
    var hub = Mock.Of<Hub>();

    // The mocked hub has a generated runtime type, so the matching end point type
    // is built via reflection and the end point is driven through `dynamic`.
    var hubType = hub.GetType();
    var endPointType = GetEndPointType(hubType);
    var provider = CreateServiceProvider(services =>
    {
        services.AddSingleton(endPointType);
        services.AddTransient(hubType, _ => hub);
    });
    dynamic endPoint = provider.GetService(endPointType);

    using (var wrapper = new ConnectionWrapper())
    {
        var endPointTask = endPoint.OnConnectedAsync(wrapper.Connection);
        await wrapper.ApplicationStartedReading;

        // Write a payload to the application output, then tear the connection down.
        var payload = Encoding.UTF8.GetBytes("0xdeadbeef");
        var output = wrapper.ConnectionState.Application.Output.Alloc();
        output.Write(payload);
        await output.FlushAsync();
        wrapper.Dispose();

        // InvalidCastException because the payload is not a JObject,
        // which is what the formatter expects.
        await Assert.ThrowsAsync<InvalidCastException>(async () => await endPointTask);

        // The hub must have been told about the disconnect with a non-null exception, exactly once.
        var hubMock = Mock.Get(hub);
        hubMock.Verify(h => h.OnDisconnectedAsync(It.IsNotNull<Exception>()), Times.Once());
    }
}
[Fact]
public async Task LifetimeManagerOnDisconnectedAsyncCalledIfLifetimeManagerOnConnectedAsyncThrows()
{
    const string failureMessage = "Lifetime manager OnConnectedAsync failed.";

    // Lifetime manager whose OnConnectedAsync throws for any connection.
    var lifetimeManager = new Mock<HubLifetimeManager<Hub>>();
    lifetimeManager
        .Setup(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()))
        .Throws(new InvalidOperationException(failureMessage));

    var hubActivator = new Mock<IHubActivator<Hub, IClientProxy>>();

    var provider = CreateServiceProvider(services =>
    {
        services.AddSingleton(lifetimeManager.Object);
        services.AddSingleton(hubActivator.Object);
    });
    var hubEndPoint = provider.GetService<HubEndPoint<Hub>>();

    using (var wrapper = new ConnectionWrapper())
    {
        // The lifetime manager failure must surface from the end point unchanged.
        var thrown = await Assert.ThrowsAsync<InvalidOperationException>(
            async () => await hubEndPoint.OnConnectedAsync(wrapper.Connection));
        Assert.Equal(failureMessage, thrown.Message);

        wrapper.ConnectionState.Dispose();

        // Connect and disconnect are each expected exactly once on the lifetime manager.
        lifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
        lifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);

        // No hubs should be created since the connection is terminated
        hubActivator.Verify(m => m.Create(), Times.Never);
        hubActivator.Verify(m => m.Release(It.IsAny<Hub>()), Times.Never);
    }
}
[Fact]
public async Task HubOnDisconnectedAsyncCalledIfHubOnConnectedAsyncThrows()
{
    var lifetimeManager = new Mock<HubLifetimeManager<OnConnectedThrowsHub>>();
    var provider = CreateServiceProvider(services => services.AddSingleton(lifetimeManager.Object));
    var hubEndPoint = provider.GetService<HubEndPoint<OnConnectedThrowsHub>>();

    using (var wrapper = new ConnectionWrapper())
    {
        // Start the end point and dispose the connection state right away.
        var pendingEndPoint = hubEndPoint.OnConnectedAsync(wrapper.Connection);
        wrapper.ConnectionState.Dispose();

        // The failure raised by OnConnectedThrowsHub.OnConnectedAsync must surface from the end point.
        var thrown = await Assert.ThrowsAsync<InvalidOperationException>(async () => await pendingEndPoint);
        Assert.Equal("Hub OnConnected failed.", thrown.Message);

        // What is verified here is the lifetime manager: one connect and one disconnect.
        lifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
        lifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
    }
}
[Fact]
public async Task LifetimeManagerOnDisconnectedAsyncCalledIfHubOnDisconnectedAsyncThrows()
{
    var lifetimeManager = new Mock<HubLifetimeManager<OnDisconnectedThrowsHub>>();
    var provider = CreateServiceProvider(services => services.AddSingleton(lifetimeManager.Object));
    var hubEndPoint = provider.GetService<HubEndPoint<OnDisconnectedThrowsHub>>();

    using (var wrapper = new ConnectionWrapper())
    {
        // Start the end point and dispose the connection state right away.
        var pendingEndPoint = hubEndPoint.OnConnectedAsync(wrapper.Connection);
        wrapper.ConnectionState.Dispose();

        // The failure raised by OnDisconnectedThrowsHub.OnDisconnectedAsync must surface from the end point.
        var thrown = await Assert.ThrowsAsync<InvalidOperationException>(async () => await pendingEndPoint);
        Assert.Equal("Hub OnDisconnected failed.", thrown.Message);

        // Even though the hub callback failed, the lifetime manager must see one connect and one disconnect.
        lifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
        lifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
    }
}
/// <summary>
/// Closes the open generic <c>HubEndPoint&lt;&gt;</c> over <paramref name="hubType"/>.
/// </summary>
/// <param name="hubType">The hub type used as the generic argument.</param>
/// <returns>The constructed <c>HubEndPoint&lt;hubType&gt;</c> type.</returns>
private static Type GetEndPointType(Type hubType) =>
    typeof(HubEndPoint<>).MakeGenericType(hubType);
/// <summary>
/// Closes the open generic <paramref name="genericType"/> over <paramref name="hubType"/>.
/// </summary>
/// <param name="genericType">An open generic type definition with a single type parameter.</param>
/// <param name="hubType">The type supplied as the generic argument.</param>
/// <returns>The constructed generic type.</returns>
private static Type GetGenericType(Type genericType, Type hubType)
{
    var closedType = genericType.MakeGenericType(hubType);
    return closedType;
}
/// <summary>
/// Hub whose <see cref="OnConnectedAsync"/> always returns a task faulted with an
/// <see cref="InvalidOperationException"/>.
/// </summary>
public class OnConnectedThrowsHub : Hub
{
    public override Task OnConnectedAsync()
    {
        // Return an already-faulted task rather than throwing synchronously.
        var failure = new InvalidOperationException("Hub OnConnected failed.");
        return Task.FromException<object>(failure);
    }
}
/// <summary>
/// Hub whose <see cref="OnDisconnectedAsync"/> always returns a task faulted with an
/// <see cref="InvalidOperationException"/>, regardless of the exception passed in.
/// </summary>
public class OnDisconnectedThrowsHub : Hub
{
    // The incoming exception is ignored; an already-faulted task is returned instead of throwing synchronously.
    public override Task OnDisconnectedAsync(Exception exception) =>
        Task.FromException<object>(new InvalidOperationException("Hub OnDisconnected failed."));
}
/// <summary>
/// Hub that records each of its disposals on the <see cref="TrackDispose"/> instance
/// handed to its constructor.
/// </summary>
private class TestHub : Hub
{
    // Assigned once in the constructor and never replaced.
    private readonly TrackDispose _trackDispose;

    /// <param name="trackDispose">Counter that is incremented when this hub is disposed.</param>
    public TestHub(TrackDispose trackDispose)
    {
        _trackDispose = trackDispose;
    }

    /// <summary>
    /// Increments the dispose counter when called with <c>true</c>, then defers to the base class.
    /// </summary>
    /// <param name="dispose"><c>true</c> when invoked from an explicit dispose.</param>
    protected override void Dispose(bool dispose)
    {
        if (dispose)
        {
            _trackDispose.DisposeCount++;
        }

        // An override of Dispose(bool) should always chain to the base implementation
        // so that any cleanup owned by the base class still runs.
        base.Dispose(dispose);
    }
}
/// <summary>
/// Simple mutable counter that TestHub increments from its Dispose(bool) override.
/// </summary>
private class TrackDispose
{
    // Starts at zero (the default for int); read by HubsAreDisposed.
    public int DisposeCount;
}
/// <summary>
/// Builds a service provider with options, logging and SignalR registered, followed by
/// any registrations supplied by the caller.
/// </summary>
/// <param name="addServices">
/// Optional callback invoked with the service collection after the default registrations.
/// </param>
/// <returns>The provider built from the resulting service collection.</returns>
private IServiceProvider CreateServiceProvider(Action<ServiceCollection> addServices = null)
{
    var collection = new ServiceCollection();

    // Default registrations, applied in the same order as a fluent chain would.
    var withOptions = collection.AddOptions();
    var withLogging = withOptions.AddLogging();
    withLogging.AddSignalR();

    // Caller registrations run last, after the defaults above.
    if (addServices != null)
    {
        addServices(collection);
    }

    return collection.BuildServiceProvider();
}
/// <summary>
/// Test helper that creates a streaming connection through a <see cref="ConnectionManager"/>
/// backed by its own <see cref="PipelineFactory"/>, and disposes both the connection state
/// and the factory when disposed.
/// </summary>
private class ConnectionWrapper : IDisposable
{
    private readonly PipelineFactory _factory;

    // Guards Dispose so that it is safe to call more than once; tests may dispose the
    // wrapper explicitly while it is also owned by a `using` statement.
    private bool _disposed;

    public StreamingConnectionState ConnectionState;

    public StreamingConnection Connection => ConnectionState.Connection;

    // Still kinda gross: this casts the transport input to PipelineReaderWriter
    // in order to reach its ReadingStarted task.
    public Task ApplicationStartedReading => ((PipelineReaderWriter)Connection.Transport.Input).ReadingStarted;

    /// <param name="format">Value stored under the "formatType" key of the connection metadata.</param>
    public ConnectionWrapper(string format = "json")
    {
        _factory = new PipelineFactory();
        var connectionManager = new ConnectionManager(_factory);

        ConnectionState = (StreamingConnectionState)connectionManager.CreateConnection(ConnectionMode.Streaming);
        ConnectionState.Connection.Metadata["formatType"] = format;
        ConnectionState.Connection.User = new ClaimsPrincipal(new ClaimsIdentity());
    }

    /// <summary>
    /// Disposes the connection state and then the pipeline factory. Calls after the first are no-ops.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }

        _disposed = true;

        try
        {
            ConnectionState.Dispose();
        }
        finally
        {
            // Release the factory even if disposing the connection state fails.
            _factory.Dispose();
        }
    }
}
}
}