Replace ConnectionContext with HubConnectionContext (#629)

* Replace ConnectionContext with HubConnectionContext
- The SocketDelegate implementation owns the transport pipe,
it's a single producer single consumer model. SignalR needs to support
multiple producers so that broadcast, return values and sending to individual
connections works. This change introduces a multi producer channel that is used
by all producers to copy data to the transport safely. This will make the move
to pipelines easier.
This commit is contained in:
David Fowler 2017-07-03 17:44:28 -07:00 committed by GitHub
parent 652afa7023
commit f21f5039b2
16 changed files with 187 additions and 77 deletions

View File

@ -4,15 +4,15 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.SignalR;
namespace ChatSample namespace ChatSample
{ {
public interface IUserTracker<out THub> public interface IUserTracker<out THub>
{ {
Task<IEnumerable<UserDetails>> UsersOnline(); Task<IEnumerable<UserDetails>> UsersOnline();
Task AddUser(ConnectionContext connection, UserDetails userDetails); Task AddUser(HubConnectionContext connection, UserDetails userDetails);
Task RemoveUser(ConnectionContext connection); Task RemoveUser(HubConnectionContext connection);
event Action<UserDetails[]> UsersJoined; event Action<UserDetails[]> UsersJoined;
event Action<UserDetails[]> UsersLeft; event Action<UserDetails[]> UsersLeft;

View File

@ -3,14 +3,14 @@ using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.SignalR;
namespace ChatSample namespace ChatSample
{ {
public class InMemoryUserTracker<THub> : IUserTracker<THub> public class InMemoryUserTracker<THub> : IUserTracker<THub>
{ {
private readonly ConcurrentDictionary<ConnectionContext, UserDetails> _usersOnline private readonly ConcurrentDictionary<HubConnectionContext, UserDetails> _usersOnline
= new ConcurrentDictionary<ConnectionContext, UserDetails>(); = new ConcurrentDictionary<HubConnectionContext, UserDetails>();
public event Action<UserDetails[]> UsersJoined; public event Action<UserDetails[]> UsersJoined;
public event Action<UserDetails[]> UsersLeft; public event Action<UserDetails[]> UsersLeft;
@ -18,7 +18,7 @@ namespace ChatSample
public Task<IEnumerable<UserDetails>> UsersOnline() public Task<IEnumerable<UserDetails>> UsersOnline()
=> Task.FromResult(_usersOnline.Values.AsEnumerable()); => Task.FromResult(_usersOnline.Values.AsEnumerable());
public Task AddUser(ConnectionContext connection, UserDetails userDetails) public Task AddUser(HubConnectionContext connection, UserDetails userDetails)
{ {
_usersOnline.TryAdd(connection, userDetails); _usersOnline.TryAdd(connection, userDetails);
UsersJoined(new[] { userDetails }); UsersJoined(new[] { userDetails });
@ -26,7 +26,7 @@ namespace ChatSample
return Task.CompletedTask; return Task.CompletedTask;
} }
public Task RemoveUser(ConnectionContext connection) public Task RemoveUser(HubConnectionContext connection)
{ {
if (_usersOnline.TryRemove(connection, out var userDetails)) if (_usersOnline.TryRemove(connection, out var userDetails))
{ {

View File

@ -36,7 +36,7 @@ namespace ChatSample
where THubLifetimeManager : HubLifetimeManager<THub> where THubLifetimeManager : HubLifetimeManager<THub>
where THub : HubWithPresence where THub : HubWithPresence
{ {
private readonly ConnectionList _connections = new ConnectionList(); private readonly HubConnectionList _connections = new HubConnectionList();
private readonly IUserTracker<THub> _userTracker; private readonly IUserTracker<THub> _userTracker;
private readonly IServiceScopeFactory _serviceScopeFactory; private readonly IServiceScopeFactory _serviceScopeFactory;
private readonly ILogger _logger; private readonly ILogger _logger;
@ -57,14 +57,14 @@ namespace ChatSample
_wrappedHubLifetimeManager = serviceProvider.GetRequiredService<THubLifetimeManager>(); _wrappedHubLifetimeManager = serviceProvider.GetRequiredService<THubLifetimeManager>();
} }
public override async Task OnConnectedAsync(ConnectionContext connection) public override async Task OnConnectedAsync(HubConnectionContext connection)
{ {
await _wrappedHubLifetimeManager.OnConnectedAsync(connection); await _wrappedHubLifetimeManager.OnConnectedAsync(connection);
_connections.Add(connection); _connections.Add(connection);
await _userTracker.AddUser(connection, new UserDetails(connection.ConnectionId, connection.User.Identity.Name)); await _userTracker.AddUser(connection, new UserDetails(connection.ConnectionId, connection.User.Identity.Name));
} }
public override async Task OnDisconnectedAsync(ConnectionContext connection) public override async Task OnDisconnectedAsync(HubConnectionContext connection)
{ {
await _wrappedHubLifetimeManager.OnDisconnectedAsync(connection); await _wrappedHubLifetimeManager.OnDisconnectedAsync(connection);
_connections.Remove(connection); _connections.Remove(connection);

View File

@ -9,8 +9,8 @@ using System.Net;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Redis; using Microsoft.AspNetCore.SignalR.Redis;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Newtonsoft.Json; using Newtonsoft.Json;
@ -129,7 +129,7 @@ namespace ChatSample
} }
} }
public async Task AddUser(ConnectionContext connection, UserDetails userDetails) public async Task AddUser(HubConnectionContext connection, UserDetails userDetails)
{ {
var key = GetUserRedisKey(connection); var key = GetUserRedisKey(connection);
var user = SerializeUser(connection); var user = SerializeUser(connection);
@ -156,7 +156,7 @@ namespace ChatSample
} }
} }
public async Task RemoveUser(ConnectionContext connection) public async Task RemoveUser(HubConnectionContext connection)
{ {
await _userSyncSempaphore.WaitAsync(); await _userSyncSempaphore.WaitAsync();
try try
@ -180,7 +180,7 @@ namespace ChatSample
} }
} }
private static string GetUserRedisKey(ConnectionContext connection) => $"user:{connection.ConnectionId}"; private static string GetUserRedisKey(HubConnectionContext connection) => $"user:{connection.ConnectionId}";
private static void Scan(object state) private static void Scan(object state)
{ {
@ -319,7 +319,7 @@ namespace ChatSample
} }
} }
private static string SerializeUser(ConnectionContext connection) => private static string SerializeUser(HubConnectionContext connection) =>
$"{{ \"ConnectionID\": \"{connection.ConnectionId}\", \"Name\": \"{connection.User.Identity.Name}\" }}"; $"{{ \"ConnectionID\": \"{connection.ConnectionId}\", \"Name\": \"{connection.User.Identity.Name}\" }}";
private static UserDetails DeserializerUser(string userJson) => private static UserDetails DeserializerUser(string userJson) =>

View File

@ -10,7 +10,6 @@ using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Newtonsoft.Json; using Newtonsoft.Json;
@ -22,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{ {
private const string RedisSubscriptionsMetadataName = "redis_subscriptions"; private const string RedisSubscriptionsMetadataName = "redis_subscriptions";
private readonly ConnectionList _connections = new ConnectionList(); private readonly HubConnectionList _connections = new HubConnectionList();
// TODO: Investigate "memory leak" entries never get removed // TODO: Investigate "memory leak" entries never get removed
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>(); private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>();
private readonly ConnectionMultiplexer _redisServerConnection; private readonly ConnectionMultiplexer _redisServerConnection;
@ -128,7 +127,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
await _bus.PublishAsync(channel, payload); await _bus.PublishAsync(channel, payload);
} }
public override Task OnConnectedAsync(ConnectionContext connection) public override Task OnConnectedAsync(HubConnectionContext connection)
{ {
var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet<string>()); var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet<string>());
var connectionTask = Task.CompletedTask; var connectionTask = Task.CompletedTask;
@ -173,7 +172,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return Task.WhenAll(connectionTask, userTask); return Task.WhenAll(connectionTask, userTask);
} }
public override Task OnDisconnectedAsync(ConnectionContext connection) public override Task OnDisconnectedAsync(HubConnectionContext connection)
{ {
_connections.Remove(connection); _connections.Remove(connection);
@ -307,14 +306,13 @@ namespace Microsoft.AspNetCore.SignalR.Redis
_redisServerConnection.Dispose(); _redisServerConnection.Dispose();
} }
private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage) private async Task WriteAsync(HubConnectionContext connection, HubMessage hubMessage)
{ {
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol); var data = connection.Protocol.WriteToArray(hubMessage);
var data = protocol.WriteToArray(hubMessage);
while (await connection.Transport.Out.WaitToWriteAsync()) while (await connection.Output.WaitToWriteAsync())
{ {
if (connection.Transport.Out.TryWrite(data)) if (connection.Output.TryWrite(data))
{ {
break; break;
} }
@ -363,7 +361,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
private class GroupData private class GroupData
{ {
public SemaphoreSlim Lock = new SemaphoreSlim(1, 1); public SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
public ConnectionList Connections = new ConnectionList(); public HubConnectionList Connections = new HubConnectionList();
} }
} }
} }

View File

@ -3,19 +3,16 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
namespace Microsoft.AspNetCore.SignalR namespace Microsoft.AspNetCore.SignalR
{ {
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub> public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
{ {
private long _nextInvocationId = 0; private long _nextInvocationId = 0;
private readonly ConnectionList _connections = new ConnectionList(); private readonly HubConnectionList _connections = new HubConnectionList();
public override Task AddGroupAsync(string connectionId, string groupName) public override Task AddGroupAsync(string connectionId, string groupName)
{ {
@ -62,7 +59,7 @@ namespace Microsoft.AspNetCore.SignalR
return InvokeAllWhere(methodName, args, c => true); return InvokeAllWhere(methodName, args, c => true);
} }
private Task InvokeAllWhere(string methodName, object[] args, Func<ConnectionContext, bool> include) private Task InvokeAllWhere(string methodName, object[] args, Func<HubConnectionContext, bool> include)
{ {
var tasks = new List<Task>(_connections.Count); var tasks = new List<Task>(_connections.Count);
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
@ -107,26 +104,25 @@ namespace Microsoft.AspNetCore.SignalR
}); });
} }
public override Task OnConnectedAsync(ConnectionContext connection) public override Task OnConnectedAsync(HubConnectionContext connection)
{ {
_connections.Add(connection); _connections.Add(connection);
return Task.CompletedTask; return Task.CompletedTask;
} }
public override Task OnDisconnectedAsync(ConnectionContext connection) public override Task OnDisconnectedAsync(HubConnectionContext connection)
{ {
_connections.Remove(connection); _connections.Remove(connection);
return Task.CompletedTask; return Task.CompletedTask;
} }
private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage) private async Task WriteAsync(HubConnectionContext connection, HubMessage hubMessage)
{ {
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol); var payload = connection.Protocol.WriteToArray(hubMessage);
var payload = protocol.WriteToArray(hubMessage);
while (await connection.Transport.Out.WaitToWriteAsync()) while (await connection.Output.WaitToWriteAsync())
{ {
if (connection.Transport.Out.TryWrite(payload)) if (connection.Output.TryWrite(payload))
{ {
break; break;
} }

View File

@ -2,18 +2,17 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Security.Claims; using System.Security.Claims;
using Microsoft.AspNetCore.Sockets;
namespace Microsoft.AspNetCore.SignalR namespace Microsoft.AspNetCore.SignalR
{ {
public class HubCallerContext public class HubCallerContext
{ {
public HubCallerContext(ConnectionContext connection) public HubCallerContext(HubConnectionContext connection)
{ {
Connection = connection; Connection = connection;
} }
public ConnectionContext Connection { get; } public HubConnectionContext Connection { get; }
public ClaimsPrincipal User => Connection.User; public ClaimsPrincipal User => Connection.User;

View File

@ -0,0 +1,36 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Security.Claims;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
namespace Microsoft.AspNetCore.SignalR
{
public class HubConnectionContext
{
private readonly WritableChannel<byte[]> _output;
private readonly ConnectionContext _connectionContext;
public HubConnectionContext(WritableChannel<byte[]> output, ConnectionContext connectionContext)
{
_output = output;
_connectionContext = connectionContext;
}
// Used by the HubEndPoint only
internal ReadableChannel<byte[]> Input => _connectionContext.Transport;
public virtual string ConnectionId => _connectionContext.ConnectionId;
public virtual ClaimsPrincipal User => _connectionContext.User;
public virtual ConnectionMetadata Metadata => _connectionContext.Metadata;
public virtual IHubProtocol Protocol => _connectionContext.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
public virtual WritableChannel<byte[]> Output => _output;
}
}

View File

@ -0,0 +1,52 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
namespace Microsoft.AspNetCore.SignalR
{
public class HubConnectionList : IReadOnlyCollection<HubConnectionContext>
{
private readonly ConcurrentDictionary<string, HubConnectionContext> _connections = new ConcurrentDictionary<string, HubConnectionContext>();
public HubConnectionContext this[string connectionId]
{
get
{
if (_connections.TryGetValue(connectionId, out var connection))
{
return connection;
}
return null;
}
}
public int Count => _connections.Count;
public void Add(HubConnectionContext connection)
{
_connections.TryAdd(connection.ConnectionId, connection);
}
public void Remove(HubConnectionContext connection)
{
_connections.TryRemove(connection.ConnectionId, out _);
}
public IEnumerator<HubConnectionContext> GetEnumerator()
{
foreach (var item in _connections)
{
yield return item.Value;
}
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
}

View File

@ -59,24 +59,54 @@ namespace Microsoft.AspNetCore.SignalR
public async Task OnConnectedAsync(ConnectionContext connection) public async Task OnConnectedAsync(ConnectionContext connection)
{ {
await ProcessNegotiate(connection); var output = Channel.CreateUnbounded<byte[]>();
var connectionContext = new HubConnectionContext(output, connection);
await ProcessNegotiate(connectionContext);
// Hubs support multiple producers so we set up this loop to copy
// data written to the HubConnectionContext's channel to the transport channel
async Task WriteToTransport()
{
while (await output.In.WaitToReadAsync())
{
while (output.In.TryRead(out var buffer))
{
while (await connection.Transport.Out.WaitToWriteAsync())
{
if (connection.Transport.Out.TryWrite(buffer))
{
break;
}
}
}
}
}
var writingOutputTask = WriteToTransport();
try try
{ {
await _lifetimeManager.OnConnectedAsync(connection); await _lifetimeManager.OnConnectedAsync(connectionContext);
await RunHubAsync(connection); await RunHubAsync(connectionContext);
} }
finally finally
{ {
await _lifetimeManager.OnDisconnectedAsync(connection); await _lifetimeManager.OnDisconnectedAsync(connectionContext);
// Nothing should be writing to the HubConnectionContext
output.Out.TryComplete();
// This should unwind once we complete the output
await writingOutputTask;
} }
} }
private async Task ProcessNegotiate(ConnectionContext connection) private async Task ProcessNegotiate(HubConnectionContext connection)
{ {
while (await connection.Transport.In.WaitToReadAsync()) while (await connection.Input.WaitToReadAsync())
{ {
while (connection.Transport.In.TryRead(out var buffer)) while (connection.Input.TryRead(out var buffer))
{ {
if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage)) if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage))
{ {
@ -92,7 +122,7 @@ namespace Microsoft.AspNetCore.SignalR
} }
} }
private async Task RunHubAsync(ConnectionContext connection) private async Task RunHubAsync(HubConnectionContext connection)
{ {
await HubOnConnectedAsync(connection); await HubOnConnectedAsync(connection);
@ -110,7 +140,7 @@ namespace Microsoft.AspNetCore.SignalR
await HubOnDisconnectedAsync(connection, null); await HubOnDisconnectedAsync(connection, null);
} }
private async Task HubOnConnectedAsync(ConnectionContext connection) private async Task HubOnConnectedAsync(HubConnectionContext connection)
{ {
try try
{ {
@ -136,7 +166,7 @@ namespace Microsoft.AspNetCore.SignalR
} }
} }
private async Task HubOnDisconnectedAsync(ConnectionContext connection, Exception exception) private async Task HubOnDisconnectedAsync(HubConnectionContext connection, Exception exception)
{ {
try try
{ {
@ -162,7 +192,7 @@ namespace Microsoft.AspNetCore.SignalR
} }
} }
private async Task DispatchMessagesAsync(ConnectionContext connection) private async Task DispatchMessagesAsync(HubConnectionContext connection)
{ {
// We use these for error handling. Since we dispatch multiple hub invocations // We use these for error handling. Since we dispatch multiple hub invocations
// in parallel, we need a way to communicate failure back to the main processing loop. The // in parallel, we need a way to communicate failure back to the main processing loop. The
@ -174,9 +204,9 @@ namespace Microsoft.AspNetCore.SignalR
try try
{ {
while (await connection.Transport.In.WaitToReadAsync(cts.Token)) while (await connection.Input.WaitToReadAsync(cts.Token))
{ {
while (connection.Transport.In.TryRead(out var buffer)) while (connection.Input.TryRead(out var buffer))
{ {
if (protocol.TryParseMessages(buffer, this, out var hubMessages)) if (protocol.TryParseMessages(buffer, this, out var hubMessages))
{ {
@ -212,7 +242,7 @@ namespace Microsoft.AspNetCore.SignalR
} }
} }
private async Task ProcessInvocation(ConnectionContext connection, private async Task ProcessInvocation(HubConnectionContext connection,
IHubProtocol protocol, IHubProtocol protocol,
InvocationMessage invocationMessage, InvocationMessage invocationMessage,
CancellationTokenSource dispatcherCancellation, CancellationTokenSource dispatcherCancellation,
@ -234,7 +264,7 @@ namespace Microsoft.AspNetCore.SignalR
} }
} }
private async Task Execute(ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage) private async Task Execute(HubConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage)
{ {
if (!_methods.TryGetValue(invocationMessage.Target, out var descriptor)) if (!_methods.TryGetValue(invocationMessage.Target, out var descriptor))
{ {
@ -248,13 +278,13 @@ namespace Microsoft.AspNetCore.SignalR
} }
} }
private async Task SendMessageAsync(ConnectionContext connection, IHubProtocol protocol, HubMessage hubMessage) private async Task SendMessageAsync(HubConnectionContext connection, IHubProtocol protocol, HubMessage hubMessage)
{ {
var payload = protocol.WriteToArray(hubMessage); var payload = protocol.WriteToArray(hubMessage);
while (await connection.Transport.Out.WaitToWriteAsync()) while (await connection.Output.WaitToWriteAsync())
{ {
if (connection.Transport.Out.TryWrite(payload)) if (connection.Output.TryWrite(payload))
{ {
return; return;
} }
@ -265,7 +295,7 @@ namespace Microsoft.AspNetCore.SignalR
throw new OperationCanceledException("Outbound channel was closed while trying to write hub message"); throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
} }
private async Task Invoke(HubMethodDescriptor descriptor, ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage) private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage)
{ {
var methodExecutor = descriptor.MethodExecutor; var methodExecutor = descriptor.MethodExecutor;
@ -341,7 +371,7 @@ namespace Microsoft.AspNetCore.SignalR
} }
} }
private void InitializeHub(THub hub, ConnectionContext connection) private void InitializeHub(THub hub, HubConnectionContext connection)
{ {
hub.Clients = _hubContext.Clients; hub.Clients = _hubContext.Clients;
hub.Context = new HubCallerContext(connection); hub.Context = new HubCallerContext(connection);
@ -363,7 +393,7 @@ namespace Microsoft.AspNetCore.SignalR
} }
} }
private async Task StreamResultsAsync(string invocationId, ConnectionContext connection, IHubProtocol protocol, IAsyncEnumerator<object> enumerator) private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IHubProtocol protocol, IAsyncEnumerator<object> enumerator)
{ {
// TODO: Cancellation? See https://github.com/aspnet/SignalR/issues/481 // TODO: Cancellation? See https://github.com/aspnet/SignalR/issues/481
try try

View File

@ -2,15 +2,14 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
namespace Microsoft.AspNetCore.SignalR namespace Microsoft.AspNetCore.SignalR
{ {
public abstract class HubLifetimeManager<THub> public abstract class HubLifetimeManager<THub>
{ {
public abstract Task OnConnectedAsync(ConnectionContext connection); public abstract Task OnConnectedAsync(HubConnectionContext connection);
public abstract Task OnDisconnectedAsync(ConnectionContext connection); public abstract Task OnDisconnectedAsync(HubConnectionContext connection);
public abstract Task InvokeAllAsync(string methodName, object[] args); public abstract Task InvokeAllAsync(string methodName, object[] args);

View File

@ -10,7 +10,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{ {
public class DefaultHubProtocolResolver : IHubProtocolResolver public class DefaultHubProtocolResolver : IHubProtocolResolver
{ {
public IHubProtocol GetProtocol(string protocolName, ConnectionContext connection) public IHubProtocol GetProtocol(string protocolName, HubConnectionContext connection)
{ {
switch (protocolName?.ToLowerInvariant()) switch (protocolName?.ToLowerInvariant())
{ {

View File

@ -2,12 +2,11 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
namespace Microsoft.AspNetCore.SignalR.Internal namespace Microsoft.AspNetCore.SignalR.Internal
{ {
public interface IHubProtocolResolver public interface IHubProtocolResolver
{ {
IHubProtocol GetProtocol(string protocolName, ConnectionContext connection); IHubProtocol GetProtocol(string protocolName, HubConnectionContext connection);
} }
} }

View File

@ -43,7 +43,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{ {
var mockLifetimeManager = new Mock<HubLifetimeManager<Hub>>(); var mockLifetimeManager = new Mock<HubLifetimeManager<Hub>>();
mockLifetimeManager mockLifetimeManager
.Setup(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>())) .Setup(m => m.OnConnectedAsync(It.IsAny<HubConnectionContext>()))
.Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed.")); .Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed."));
var mockHubActivator = new Mock<IHubActivator<Hub, IClientProxy>>(); var mockHubActivator = new Mock<IHubActivator<Hub, IClientProxy>>();
@ -64,8 +64,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
client.Dispose(); client.Dispose();
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()), Times.Once); mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<HubConnectionContext>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<ConnectionContext>()), Times.Once); mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<HubConnectionContext>()), Times.Once);
// No hubs should be created since the connection is terminated // No hubs should be created since the connection is terminated
mockHubActivator.Verify(m => m.Create(), Times.Never); mockHubActivator.Verify(m => m.Create(), Times.Never);
mockHubActivator.Verify(m => m.Release(It.IsAny<Hub>()), Times.Never); mockHubActivator.Verify(m => m.Release(It.IsAny<Hub>()), Times.Never);
@ -91,8 +91,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask); var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
Assert.Equal("Hub OnConnected failed.", exception.Message); Assert.Equal("Hub OnConnected failed.", exception.Message);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()), Times.Once); mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<HubConnectionContext>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<ConnectionContext>()), Times.Once); mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<HubConnectionContext>()), Times.Once);
} }
} }
@ -115,8 +115,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask); var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
Assert.Equal("Hub OnDisconnected failed.", exception.Message); Assert.Equal("Hub OnDisconnected failed.", exception.Message);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()), Times.Once); mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<HubConnectionContext>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<ConnectionContext>()), Times.Once); mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<HubConnectionContext>()), Times.Once);
} }
} }

View File

@ -3,6 +3,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets;
@ -18,7 +19,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests
[MemberData(nameof(HubProtocols))] [MemberData(nameof(HubProtocols))]
public void DefaultHubProtocolResolverTestsCanCreateSupportedProtocols(IHubProtocol protocol) public void DefaultHubProtocolResolverTestsCanCreateSupportedProtocols(IHubProtocol protocol)
{ {
var mockConnection = new Mock<ConnectionContext>(); var mockConnection = new Mock<HubConnectionContext>(Channel.CreateUnbounded<byte[]>().Out, new Mock<ConnectionContext>().Object);
Assert.IsType( Assert.IsType(
protocol.GetType(), protocol.GetType(),
new DefaultHubProtocolResolver().GetProtocol(protocol.Name, mockConnection.Object)); new DefaultHubProtocolResolver().GetProtocol(protocol.Name, mockConnection.Object));
@ -29,7 +30,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests
[InlineData("dummy")] [InlineData("dummy")]
public void DefaultHubProtocolResolverThrowsForNotSupportedProtocol(string protocolName) public void DefaultHubProtocolResolverThrowsForNotSupportedProtocol(string protocolName)
{ {
var mockConnection = new Mock<ConnectionContext>(); var mockConnection = new Mock<HubConnectionContext>(Channel.CreateUnbounded<byte[]>().Out, new Mock<ConnectionContext>().Object);
var exception = Assert.Throws<NotSupportedException>( var exception = Assert.Throws<NotSupportedException>(
() => new DefaultHubProtocolResolver().GetProtocol(protocolName, mockConnection.Object)); () => new DefaultHubProtocolResolver().GetProtocol(protocolName, mockConnection.Object));