Replace ConnectionContext with HubConnectionContext (#629)
* Replace ConnectionContext with HubConnectionContext - The SocketDelegate implementation owns the transport pipe, it's a single producer single consumer model. SignalR needs to support multiple producers so that broadcast, return values and sending to individual connections works. This change introduces a multi producer channel that is used by all producers to copy data to the transport safely. This will make the move to pipelines easier.
This commit is contained in:
parent
652afa7023
commit
f21f5039b2
|
|
@ -4,15 +4,15 @@
|
||||||
using System;
|
using System;
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
using Microsoft.AspNetCore.Sockets;
|
using Microsoft.AspNetCore.SignalR;
|
||||||
|
|
||||||
namespace ChatSample
|
namespace ChatSample
|
||||||
{
|
{
|
||||||
public interface IUserTracker<out THub>
|
public interface IUserTracker<out THub>
|
||||||
{
|
{
|
||||||
Task<IEnumerable<UserDetails>> UsersOnline();
|
Task<IEnumerable<UserDetails>> UsersOnline();
|
||||||
Task AddUser(ConnectionContext connection, UserDetails userDetails);
|
Task AddUser(HubConnectionContext connection, UserDetails userDetails);
|
||||||
Task RemoveUser(ConnectionContext connection);
|
Task RemoveUser(HubConnectionContext connection);
|
||||||
|
|
||||||
event Action<UserDetails[]> UsersJoined;
|
event Action<UserDetails[]> UsersJoined;
|
||||||
event Action<UserDetails[]> UsersLeft;
|
event Action<UserDetails[]> UsersLeft;
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,14 @@ using System.Collections.Concurrent;
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
using System.Linq;
|
using System.Linq;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
using Microsoft.AspNetCore.Sockets;
|
using Microsoft.AspNetCore.SignalR;
|
||||||
|
|
||||||
namespace ChatSample
|
namespace ChatSample
|
||||||
{
|
{
|
||||||
public class InMemoryUserTracker<THub> : IUserTracker<THub>
|
public class InMemoryUserTracker<THub> : IUserTracker<THub>
|
||||||
{
|
{
|
||||||
private readonly ConcurrentDictionary<ConnectionContext, UserDetails> _usersOnline
|
private readonly ConcurrentDictionary<HubConnectionContext, UserDetails> _usersOnline
|
||||||
= new ConcurrentDictionary<ConnectionContext, UserDetails>();
|
= new ConcurrentDictionary<HubConnectionContext, UserDetails>();
|
||||||
|
|
||||||
public event Action<UserDetails[]> UsersJoined;
|
public event Action<UserDetails[]> UsersJoined;
|
||||||
public event Action<UserDetails[]> UsersLeft;
|
public event Action<UserDetails[]> UsersLeft;
|
||||||
|
|
@ -18,7 +18,7 @@ namespace ChatSample
|
||||||
public Task<IEnumerable<UserDetails>> UsersOnline()
|
public Task<IEnumerable<UserDetails>> UsersOnline()
|
||||||
=> Task.FromResult(_usersOnline.Values.AsEnumerable());
|
=> Task.FromResult(_usersOnline.Values.AsEnumerable());
|
||||||
|
|
||||||
public Task AddUser(ConnectionContext connection, UserDetails userDetails)
|
public Task AddUser(HubConnectionContext connection, UserDetails userDetails)
|
||||||
{
|
{
|
||||||
_usersOnline.TryAdd(connection, userDetails);
|
_usersOnline.TryAdd(connection, userDetails);
|
||||||
UsersJoined(new[] { userDetails });
|
UsersJoined(new[] { userDetails });
|
||||||
|
|
@ -26,7 +26,7 @@ namespace ChatSample
|
||||||
return Task.CompletedTask;
|
return Task.CompletedTask;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Task RemoveUser(ConnectionContext connection)
|
public Task RemoveUser(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
if (_usersOnline.TryRemove(connection, out var userDetails))
|
if (_usersOnline.TryRemove(connection, out var userDetails))
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ namespace ChatSample
|
||||||
where THubLifetimeManager : HubLifetimeManager<THub>
|
where THubLifetimeManager : HubLifetimeManager<THub>
|
||||||
where THub : HubWithPresence
|
where THub : HubWithPresence
|
||||||
{
|
{
|
||||||
private readonly ConnectionList _connections = new ConnectionList();
|
private readonly HubConnectionList _connections = new HubConnectionList();
|
||||||
private readonly IUserTracker<THub> _userTracker;
|
private readonly IUserTracker<THub> _userTracker;
|
||||||
private readonly IServiceScopeFactory _serviceScopeFactory;
|
private readonly IServiceScopeFactory _serviceScopeFactory;
|
||||||
private readonly ILogger _logger;
|
private readonly ILogger _logger;
|
||||||
|
|
@ -57,14 +57,14 @@ namespace ChatSample
|
||||||
_wrappedHubLifetimeManager = serviceProvider.GetRequiredService<THubLifetimeManager>();
|
_wrappedHubLifetimeManager = serviceProvider.GetRequiredService<THubLifetimeManager>();
|
||||||
}
|
}
|
||||||
|
|
||||||
public override async Task OnConnectedAsync(ConnectionContext connection)
|
public override async Task OnConnectedAsync(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
await _wrappedHubLifetimeManager.OnConnectedAsync(connection);
|
await _wrappedHubLifetimeManager.OnConnectedAsync(connection);
|
||||||
_connections.Add(connection);
|
_connections.Add(connection);
|
||||||
await _userTracker.AddUser(connection, new UserDetails(connection.ConnectionId, connection.User.Identity.Name));
|
await _userTracker.AddUser(connection, new UserDetails(connection.ConnectionId, connection.User.Identity.Name));
|
||||||
}
|
}
|
||||||
|
|
||||||
public override async Task OnDisconnectedAsync(ConnectionContext connection)
|
public override async Task OnDisconnectedAsync(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
await _wrappedHubLifetimeManager.OnDisconnectedAsync(connection);
|
await _wrappedHubLifetimeManager.OnDisconnectedAsync(connection);
|
||||||
_connections.Remove(connection);
|
_connections.Remove(connection);
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,8 @@ using System.Net;
|
||||||
using System.Text;
|
using System.Text;
|
||||||
using System.Threading;
|
using System.Threading;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
|
using Microsoft.AspNetCore.SignalR;
|
||||||
using Microsoft.AspNetCore.SignalR.Redis;
|
using Microsoft.AspNetCore.SignalR.Redis;
|
||||||
using Microsoft.AspNetCore.Sockets;
|
|
||||||
using Microsoft.Extensions.Logging;
|
using Microsoft.Extensions.Logging;
|
||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
using Newtonsoft.Json;
|
using Newtonsoft.Json;
|
||||||
|
|
@ -129,7 +129,7 @@ namespace ChatSample
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task AddUser(ConnectionContext connection, UserDetails userDetails)
|
public async Task AddUser(HubConnectionContext connection, UserDetails userDetails)
|
||||||
{
|
{
|
||||||
var key = GetUserRedisKey(connection);
|
var key = GetUserRedisKey(connection);
|
||||||
var user = SerializeUser(connection);
|
var user = SerializeUser(connection);
|
||||||
|
|
@ -156,7 +156,7 @@ namespace ChatSample
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task RemoveUser(ConnectionContext connection)
|
public async Task RemoveUser(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
await _userSyncSempaphore.WaitAsync();
|
await _userSyncSempaphore.WaitAsync();
|
||||||
try
|
try
|
||||||
|
|
@ -180,7 +180,7 @@ namespace ChatSample
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static string GetUserRedisKey(ConnectionContext connection) => $"user:{connection.ConnectionId}";
|
private static string GetUserRedisKey(HubConnectionContext connection) => $"user:{connection.ConnectionId}";
|
||||||
|
|
||||||
private static void Scan(object state)
|
private static void Scan(object state)
|
||||||
{
|
{
|
||||||
|
|
@ -319,7 +319,7 @@ namespace ChatSample
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static string SerializeUser(ConnectionContext connection) =>
|
private static string SerializeUser(HubConnectionContext connection) =>
|
||||||
$"{{ \"ConnectionID\": \"{connection.ConnectionId}\", \"Name\": \"{connection.User.Identity.Name}\" }}";
|
$"{{ \"ConnectionID\": \"{connection.ConnectionId}\", \"Name\": \"{connection.User.Identity.Name}\" }}";
|
||||||
|
|
||||||
private static UserDetails DeserializerUser(string userJson) =>
|
private static UserDetails DeserializerUser(string userJson) =>
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ using System.Text;
|
||||||
using System.Threading;
|
using System.Threading;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||||
using Microsoft.AspNetCore.Sockets;
|
|
||||||
using Microsoft.Extensions.Logging;
|
using Microsoft.Extensions.Logging;
|
||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
using Newtonsoft.Json;
|
using Newtonsoft.Json;
|
||||||
|
|
@ -22,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
||||||
{
|
{
|
||||||
private const string RedisSubscriptionsMetadataName = "redis_subscriptions";
|
private const string RedisSubscriptionsMetadataName = "redis_subscriptions";
|
||||||
|
|
||||||
private readonly ConnectionList _connections = new ConnectionList();
|
private readonly HubConnectionList _connections = new HubConnectionList();
|
||||||
// TODO: Investigate "memory leak" entries never get removed
|
// TODO: Investigate "memory leak" entries never get removed
|
||||||
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>();
|
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>();
|
||||||
private readonly ConnectionMultiplexer _redisServerConnection;
|
private readonly ConnectionMultiplexer _redisServerConnection;
|
||||||
|
|
@ -128,7 +127,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
||||||
await _bus.PublishAsync(channel, payload);
|
await _bus.PublishAsync(channel, payload);
|
||||||
}
|
}
|
||||||
|
|
||||||
public override Task OnConnectedAsync(ConnectionContext connection)
|
public override Task OnConnectedAsync(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet<string>());
|
var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet<string>());
|
||||||
var connectionTask = Task.CompletedTask;
|
var connectionTask = Task.CompletedTask;
|
||||||
|
|
@ -173,7 +172,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
||||||
return Task.WhenAll(connectionTask, userTask);
|
return Task.WhenAll(connectionTask, userTask);
|
||||||
}
|
}
|
||||||
|
|
||||||
public override Task OnDisconnectedAsync(ConnectionContext connection)
|
public override Task OnDisconnectedAsync(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
_connections.Remove(connection);
|
_connections.Remove(connection);
|
||||||
|
|
||||||
|
|
@ -307,14 +306,13 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
||||||
_redisServerConnection.Dispose();
|
_redisServerConnection.Dispose();
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage)
|
private async Task WriteAsync(HubConnectionContext connection, HubMessage hubMessage)
|
||||||
{
|
{
|
||||||
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
|
var data = connection.Protocol.WriteToArray(hubMessage);
|
||||||
var data = protocol.WriteToArray(hubMessage);
|
|
||||||
|
|
||||||
while (await connection.Transport.Out.WaitToWriteAsync())
|
while (await connection.Output.WaitToWriteAsync())
|
||||||
{
|
{
|
||||||
if (connection.Transport.Out.TryWrite(data))
|
if (connection.Output.TryWrite(data))
|
||||||
{
|
{
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -363,7 +361,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
||||||
private class GroupData
|
private class GroupData
|
||||||
{
|
{
|
||||||
public SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
|
public SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
|
||||||
public ConnectionList Connections = new ConnectionList();
|
public HubConnectionList Connections = new HubConnectionList();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,19 +3,16 @@
|
||||||
|
|
||||||
using System;
|
using System;
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
using System.IO;
|
|
||||||
using System.Text;
|
|
||||||
using System.Threading;
|
using System.Threading;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||||
using Microsoft.AspNetCore.Sockets;
|
|
||||||
|
|
||||||
namespace Microsoft.AspNetCore.SignalR
|
namespace Microsoft.AspNetCore.SignalR
|
||||||
{
|
{
|
||||||
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
|
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
|
||||||
{
|
{
|
||||||
private long _nextInvocationId = 0;
|
private long _nextInvocationId = 0;
|
||||||
private readonly ConnectionList _connections = new ConnectionList();
|
private readonly HubConnectionList _connections = new HubConnectionList();
|
||||||
|
|
||||||
public override Task AddGroupAsync(string connectionId, string groupName)
|
public override Task AddGroupAsync(string connectionId, string groupName)
|
||||||
{
|
{
|
||||||
|
|
@ -62,7 +59,7 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
return InvokeAllWhere(methodName, args, c => true);
|
return InvokeAllWhere(methodName, args, c => true);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Task InvokeAllWhere(string methodName, object[] args, Func<ConnectionContext, bool> include)
|
private Task InvokeAllWhere(string methodName, object[] args, Func<HubConnectionContext, bool> include)
|
||||||
{
|
{
|
||||||
var tasks = new List<Task>(_connections.Count);
|
var tasks = new List<Task>(_connections.Count);
|
||||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||||
|
|
@ -107,26 +104,25 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
public override Task OnConnectedAsync(ConnectionContext connection)
|
public override Task OnConnectedAsync(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
_connections.Add(connection);
|
_connections.Add(connection);
|
||||||
return Task.CompletedTask;
|
return Task.CompletedTask;
|
||||||
}
|
}
|
||||||
|
|
||||||
public override Task OnDisconnectedAsync(ConnectionContext connection)
|
public override Task OnDisconnectedAsync(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
_connections.Remove(connection);
|
_connections.Remove(connection);
|
||||||
return Task.CompletedTask;
|
return Task.CompletedTask;
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage)
|
private async Task WriteAsync(HubConnectionContext connection, HubMessage hubMessage)
|
||||||
{
|
{
|
||||||
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
|
var payload = connection.Protocol.WriteToArray(hubMessage);
|
||||||
var payload = protocol.WriteToArray(hubMessage);
|
|
||||||
|
|
||||||
while (await connection.Transport.Out.WaitToWriteAsync())
|
while (await connection.Output.WaitToWriteAsync())
|
||||||
{
|
{
|
||||||
if (connection.Transport.Out.TryWrite(payload))
|
if (connection.Output.TryWrite(payload))
|
||||||
{
|
{
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,18 +2,17 @@
|
||||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
using System.Security.Claims;
|
using System.Security.Claims;
|
||||||
using Microsoft.AspNetCore.Sockets;
|
|
||||||
|
|
||||||
namespace Microsoft.AspNetCore.SignalR
|
namespace Microsoft.AspNetCore.SignalR
|
||||||
{
|
{
|
||||||
public class HubCallerContext
|
public class HubCallerContext
|
||||||
{
|
{
|
||||||
public HubCallerContext(ConnectionContext connection)
|
public HubCallerContext(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
Connection = connection;
|
Connection = connection;
|
||||||
}
|
}
|
||||||
|
|
||||||
public ConnectionContext Connection { get; }
|
public HubConnectionContext Connection { get; }
|
||||||
|
|
||||||
public ClaimsPrincipal User => Connection.User;
|
public ClaimsPrincipal User => Connection.User;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
// Copyright (c) .NET Foundation. All rights reserved.
|
||||||
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using System.Security.Claims;
|
||||||
|
using System.Threading.Tasks.Channels;
|
||||||
|
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||||
|
using Microsoft.AspNetCore.Sockets;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.SignalR
|
||||||
|
{
|
||||||
|
public class HubConnectionContext
|
||||||
|
{
|
||||||
|
private readonly WritableChannel<byte[]> _output;
|
||||||
|
private readonly ConnectionContext _connectionContext;
|
||||||
|
|
||||||
|
public HubConnectionContext(WritableChannel<byte[]> output, ConnectionContext connectionContext)
|
||||||
|
{
|
||||||
|
_output = output;
|
||||||
|
_connectionContext = connectionContext;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used by the HubEndPoint only
|
||||||
|
internal ReadableChannel<byte[]> Input => _connectionContext.Transport;
|
||||||
|
|
||||||
|
public virtual string ConnectionId => _connectionContext.ConnectionId;
|
||||||
|
|
||||||
|
public virtual ClaimsPrincipal User => _connectionContext.User;
|
||||||
|
|
||||||
|
public virtual ConnectionMetadata Metadata => _connectionContext.Metadata;
|
||||||
|
|
||||||
|
public virtual IHubProtocol Protocol => _connectionContext.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
|
||||||
|
|
||||||
|
public virtual WritableChannel<byte[]> Output => _output;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
// Copyright (c) .NET Foundation. All rights reserved.
|
||||||
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using System.Collections;
|
||||||
|
using System.Collections.Concurrent;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.SignalR
|
||||||
|
{
|
||||||
|
public class HubConnectionList : IReadOnlyCollection<HubConnectionContext>
|
||||||
|
{
|
||||||
|
private readonly ConcurrentDictionary<string, HubConnectionContext> _connections = new ConcurrentDictionary<string, HubConnectionContext>();
|
||||||
|
|
||||||
|
public HubConnectionContext this[string connectionId]
|
||||||
|
{
|
||||||
|
get
|
||||||
|
{
|
||||||
|
if (_connections.TryGetValue(connectionId, out var connection))
|
||||||
|
{
|
||||||
|
return connection;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int Count => _connections.Count;
|
||||||
|
|
||||||
|
public void Add(HubConnectionContext connection)
|
||||||
|
{
|
||||||
|
_connections.TryAdd(connection.ConnectionId, connection);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void Remove(HubConnectionContext connection)
|
||||||
|
{
|
||||||
|
_connections.TryRemove(connection.ConnectionId, out _);
|
||||||
|
}
|
||||||
|
|
||||||
|
public IEnumerator<HubConnectionContext> GetEnumerator()
|
||||||
|
{
|
||||||
|
foreach (var item in _connections)
|
||||||
|
{
|
||||||
|
yield return item.Value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
IEnumerator IEnumerable.GetEnumerator()
|
||||||
|
{
|
||||||
|
return GetEnumerator();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -59,24 +59,54 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
|
|
||||||
public async Task OnConnectedAsync(ConnectionContext connection)
|
public async Task OnConnectedAsync(ConnectionContext connection)
|
||||||
{
|
{
|
||||||
await ProcessNegotiate(connection);
|
var output = Channel.CreateUnbounded<byte[]>();
|
||||||
|
var connectionContext = new HubConnectionContext(output, connection);
|
||||||
|
|
||||||
|
await ProcessNegotiate(connectionContext);
|
||||||
|
|
||||||
|
// Hubs support multiple producers so we set up this loop to copy
|
||||||
|
// data written to the HubConnectionContext's channel to the transport channel
|
||||||
|
async Task WriteToTransport()
|
||||||
|
{
|
||||||
|
while (await output.In.WaitToReadAsync())
|
||||||
|
{
|
||||||
|
while (output.In.TryRead(out var buffer))
|
||||||
|
{
|
||||||
|
while (await connection.Transport.Out.WaitToWriteAsync())
|
||||||
|
{
|
||||||
|
if (connection.Transport.Out.TryWrite(buffer))
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var writingOutputTask = WriteToTransport();
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
await _lifetimeManager.OnConnectedAsync(connection);
|
await _lifetimeManager.OnConnectedAsync(connectionContext);
|
||||||
await RunHubAsync(connection);
|
await RunHubAsync(connectionContext);
|
||||||
}
|
}
|
||||||
finally
|
finally
|
||||||
{
|
{
|
||||||
await _lifetimeManager.OnDisconnectedAsync(connection);
|
await _lifetimeManager.OnDisconnectedAsync(connectionContext);
|
||||||
|
|
||||||
|
// Nothing should be writing to the HubConnectionContext
|
||||||
|
output.Out.TryComplete();
|
||||||
|
|
||||||
|
// This should unwind once we complete the output
|
||||||
|
await writingOutputTask;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task ProcessNegotiate(ConnectionContext connection)
|
private async Task ProcessNegotiate(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
while (await connection.Transport.In.WaitToReadAsync())
|
while (await connection.Input.WaitToReadAsync())
|
||||||
{
|
{
|
||||||
while (connection.Transport.In.TryRead(out var buffer))
|
while (connection.Input.TryRead(out var buffer))
|
||||||
{
|
{
|
||||||
if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage))
|
if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage))
|
||||||
{
|
{
|
||||||
|
|
@ -92,7 +122,7 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task RunHubAsync(ConnectionContext connection)
|
private async Task RunHubAsync(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
await HubOnConnectedAsync(connection);
|
await HubOnConnectedAsync(connection);
|
||||||
|
|
||||||
|
|
@ -110,7 +140,7 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
await HubOnDisconnectedAsync(connection, null);
|
await HubOnDisconnectedAsync(connection, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task HubOnConnectedAsync(ConnectionContext connection)
|
private async Task HubOnConnectedAsync(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
|
|
@ -136,7 +166,7 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task HubOnDisconnectedAsync(ConnectionContext connection, Exception exception)
|
private async Task HubOnDisconnectedAsync(HubConnectionContext connection, Exception exception)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
|
|
@ -162,7 +192,7 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task DispatchMessagesAsync(ConnectionContext connection)
|
private async Task DispatchMessagesAsync(HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
// We use these for error handling. Since we dispatch multiple hub invocations
|
// We use these for error handling. Since we dispatch multiple hub invocations
|
||||||
// in parallel, we need a way to communicate failure back to the main processing loop. The
|
// in parallel, we need a way to communicate failure back to the main processing loop. The
|
||||||
|
|
@ -174,9 +204,9 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
while (await connection.Transport.In.WaitToReadAsync(cts.Token))
|
while (await connection.Input.WaitToReadAsync(cts.Token))
|
||||||
{
|
{
|
||||||
while (connection.Transport.In.TryRead(out var buffer))
|
while (connection.Input.TryRead(out var buffer))
|
||||||
{
|
{
|
||||||
if (protocol.TryParseMessages(buffer, this, out var hubMessages))
|
if (protocol.TryParseMessages(buffer, this, out var hubMessages))
|
||||||
{
|
{
|
||||||
|
|
@ -212,7 +242,7 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task ProcessInvocation(ConnectionContext connection,
|
private async Task ProcessInvocation(HubConnectionContext connection,
|
||||||
IHubProtocol protocol,
|
IHubProtocol protocol,
|
||||||
InvocationMessage invocationMessage,
|
InvocationMessage invocationMessage,
|
||||||
CancellationTokenSource dispatcherCancellation,
|
CancellationTokenSource dispatcherCancellation,
|
||||||
|
|
@ -234,7 +264,7 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task Execute(ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage)
|
private async Task Execute(HubConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage)
|
||||||
{
|
{
|
||||||
if (!_methods.TryGetValue(invocationMessage.Target, out var descriptor))
|
if (!_methods.TryGetValue(invocationMessage.Target, out var descriptor))
|
||||||
{
|
{
|
||||||
|
|
@ -248,13 +278,13 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task SendMessageAsync(ConnectionContext connection, IHubProtocol protocol, HubMessage hubMessage)
|
private async Task SendMessageAsync(HubConnectionContext connection, IHubProtocol protocol, HubMessage hubMessage)
|
||||||
{
|
{
|
||||||
var payload = protocol.WriteToArray(hubMessage);
|
var payload = protocol.WriteToArray(hubMessage);
|
||||||
|
|
||||||
while (await connection.Transport.Out.WaitToWriteAsync())
|
while (await connection.Output.WaitToWriteAsync())
|
||||||
{
|
{
|
||||||
if (connection.Transport.Out.TryWrite(payload))
|
if (connection.Output.TryWrite(payload))
|
||||||
{
|
{
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
@ -265,7 +295,7 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
|
throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task Invoke(HubMethodDescriptor descriptor, ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage)
|
private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage)
|
||||||
{
|
{
|
||||||
var methodExecutor = descriptor.MethodExecutor;
|
var methodExecutor = descriptor.MethodExecutor;
|
||||||
|
|
||||||
|
|
@ -341,7 +371,7 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void InitializeHub(THub hub, ConnectionContext connection)
|
private void InitializeHub(THub hub, HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
hub.Clients = _hubContext.Clients;
|
hub.Clients = _hubContext.Clients;
|
||||||
hub.Context = new HubCallerContext(connection);
|
hub.Context = new HubCallerContext(connection);
|
||||||
|
|
@ -363,7 +393,7 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task StreamResultsAsync(string invocationId, ConnectionContext connection, IHubProtocol protocol, IAsyncEnumerator<object> enumerator)
|
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IHubProtocol protocol, IAsyncEnumerator<object> enumerator)
|
||||||
{
|
{
|
||||||
// TODO: Cancellation? See https://github.com/aspnet/SignalR/issues/481
|
// TODO: Cancellation? See https://github.com/aspnet/SignalR/issues/481
|
||||||
try
|
try
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,14 @@
|
||||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
using Microsoft.AspNetCore.Sockets;
|
|
||||||
|
|
||||||
namespace Microsoft.AspNetCore.SignalR
|
namespace Microsoft.AspNetCore.SignalR
|
||||||
{
|
{
|
||||||
public abstract class HubLifetimeManager<THub>
|
public abstract class HubLifetimeManager<THub>
|
||||||
{
|
{
|
||||||
public abstract Task OnConnectedAsync(ConnectionContext connection);
|
public abstract Task OnConnectedAsync(HubConnectionContext connection);
|
||||||
|
|
||||||
public abstract Task OnDisconnectedAsync(ConnectionContext connection);
|
public abstract Task OnDisconnectedAsync(HubConnectionContext connection);
|
||||||
|
|
||||||
public abstract Task InvokeAllAsync(string methodName, object[] args);
|
public abstract Task InvokeAllAsync(string methodName, object[] args);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
{
|
{
|
||||||
public class DefaultHubProtocolResolver : IHubProtocolResolver
|
public class DefaultHubProtocolResolver : IHubProtocolResolver
|
||||||
{
|
{
|
||||||
public IHubProtocol GetProtocol(string protocolName, ConnectionContext connection)
|
public IHubProtocol GetProtocol(string protocolName, HubConnectionContext connection)
|
||||||
{
|
{
|
||||||
switch (protocolName?.ToLowerInvariant())
|
switch (protocolName?.ToLowerInvariant())
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,11 @@
|
||||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||||
using Microsoft.AspNetCore.Sockets;
|
|
||||||
|
|
||||||
namespace Microsoft.AspNetCore.SignalR.Internal
|
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
{
|
{
|
||||||
public interface IHubProtocolResolver
|
public interface IHubProtocolResolver
|
||||||
{
|
{
|
||||||
IHubProtocol GetProtocol(string protocolName, ConnectionContext connection);
|
IHubProtocol GetProtocol(string protocolName, HubConnectionContext connection);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
{
|
{
|
||||||
var mockLifetimeManager = new Mock<HubLifetimeManager<Hub>>();
|
var mockLifetimeManager = new Mock<HubLifetimeManager<Hub>>();
|
||||||
mockLifetimeManager
|
mockLifetimeManager
|
||||||
.Setup(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()))
|
.Setup(m => m.OnConnectedAsync(It.IsAny<HubConnectionContext>()))
|
||||||
.Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed."));
|
.Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed."));
|
||||||
var mockHubActivator = new Mock<IHubActivator<Hub, IClientProxy>>();
|
var mockHubActivator = new Mock<IHubActivator<Hub, IClientProxy>>();
|
||||||
|
|
||||||
|
|
@ -64,8 +64,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
|
|
||||||
client.Dispose();
|
client.Dispose();
|
||||||
|
|
||||||
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
|
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<HubConnectionContext>()), Times.Once);
|
||||||
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
|
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<HubConnectionContext>()), Times.Once);
|
||||||
// No hubs should be created since the connection is terminated
|
// No hubs should be created since the connection is terminated
|
||||||
mockHubActivator.Verify(m => m.Create(), Times.Never);
|
mockHubActivator.Verify(m => m.Create(), Times.Never);
|
||||||
mockHubActivator.Verify(m => m.Release(It.IsAny<Hub>()), Times.Never);
|
mockHubActivator.Verify(m => m.Release(It.IsAny<Hub>()), Times.Never);
|
||||||
|
|
@ -91,8 +91,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
|
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
|
||||||
Assert.Equal("Hub OnConnected failed.", exception.Message);
|
Assert.Equal("Hub OnConnected failed.", exception.Message);
|
||||||
|
|
||||||
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
|
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<HubConnectionContext>()), Times.Once);
|
||||||
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
|
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<HubConnectionContext>()), Times.Once);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -115,8 +115,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
|
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
|
||||||
Assert.Equal("Hub OnDisconnected failed.", exception.Message);
|
Assert.Equal("Hub OnDisconnected failed.", exception.Message);
|
||||||
|
|
||||||
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
|
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<HubConnectionContext>()), Times.Once);
|
||||||
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
|
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<HubConnectionContext>()), Times.Once);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
using System;
|
using System;
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
|
using System.Threading.Tasks.Channels;
|
||||||
using Microsoft.AspNetCore.SignalR.Internal;
|
using Microsoft.AspNetCore.SignalR.Internal;
|
||||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||||
using Microsoft.AspNetCore.Sockets;
|
using Microsoft.AspNetCore.Sockets;
|
||||||
|
|
@ -18,7 +19,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests
|
||||||
[MemberData(nameof(HubProtocols))]
|
[MemberData(nameof(HubProtocols))]
|
||||||
public void DefaultHubProtocolResolverTestsCanCreateSupportedProtocols(IHubProtocol protocol)
|
public void DefaultHubProtocolResolverTestsCanCreateSupportedProtocols(IHubProtocol protocol)
|
||||||
{
|
{
|
||||||
var mockConnection = new Mock<ConnectionContext>();
|
var mockConnection = new Mock<HubConnectionContext>(Channel.CreateUnbounded<byte[]>().Out, new Mock<ConnectionContext>().Object);
|
||||||
Assert.IsType(
|
Assert.IsType(
|
||||||
protocol.GetType(),
|
protocol.GetType(),
|
||||||
new DefaultHubProtocolResolver().GetProtocol(protocol.Name, mockConnection.Object));
|
new DefaultHubProtocolResolver().GetProtocol(protocol.Name, mockConnection.Object));
|
||||||
|
|
@ -29,7 +30,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests
|
||||||
[InlineData("dummy")]
|
[InlineData("dummy")]
|
||||||
public void DefaultHubProtocolResolverThrowsForNotSupportedProtocol(string protocolName)
|
public void DefaultHubProtocolResolverThrowsForNotSupportedProtocol(string protocolName)
|
||||||
{
|
{
|
||||||
var mockConnection = new Mock<ConnectionContext>();
|
var mockConnection = new Mock<HubConnectionContext>(Channel.CreateUnbounded<byte[]>().Out, new Mock<ConnectionContext>().Object);
|
||||||
var exception = Assert.Throws<NotSupportedException>(
|
var exception = Assert.Throws<NotSupportedException>(
|
||||||
() => new DefaultHubProtocolResolver().GetProtocol(protocolName, mockConnection.Object));
|
() => new DefaultHubProtocolResolver().GetProtocol(protocolName, mockConnection.Object));
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue