diff --git a/samples/ChatSample/IUserTracker.cs b/samples/ChatSample/IUserTracker.cs index ca46483f72..9e7da32f97 100644 --- a/samples/ChatSample/IUserTracker.cs +++ b/samples/ChatSample/IUserTracker.cs @@ -4,15 +4,15 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; -using Microsoft.AspNetCore.Sockets; +using Microsoft.AspNetCore.SignalR; namespace ChatSample { public interface IUserTracker { Task> UsersOnline(); - Task AddUser(ConnectionContext connection, UserDetails userDetails); - Task RemoveUser(ConnectionContext connection); + Task AddUser(HubConnectionContext connection, UserDetails userDetails); + Task RemoveUser(HubConnectionContext connection); event Action UsersJoined; event Action UsersLeft; diff --git a/samples/ChatSample/InMemoryUserTracker.cs b/samples/ChatSample/InMemoryUserTracker.cs index 27ce19d717..484186e47a 100644 --- a/samples/ChatSample/InMemoryUserTracker.cs +++ b/samples/ChatSample/InMemoryUserTracker.cs @@ -3,14 +3,14 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; -using Microsoft.AspNetCore.Sockets; +using Microsoft.AspNetCore.SignalR; namespace ChatSample { public class InMemoryUserTracker : IUserTracker { - private readonly ConcurrentDictionary _usersOnline - = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _usersOnline + = new ConcurrentDictionary(); public event Action UsersJoined; public event Action UsersLeft; @@ -18,7 +18,7 @@ namespace ChatSample public Task> UsersOnline() => Task.FromResult(_usersOnline.Values.AsEnumerable()); - public Task AddUser(ConnectionContext connection, UserDetails userDetails) + public Task AddUser(HubConnectionContext connection, UserDetails userDetails) { _usersOnline.TryAdd(connection, userDetails); UsersJoined(new[] { userDetails }); @@ -26,7 +26,7 @@ namespace ChatSample return Task.CompletedTask; } - public Task RemoveUser(ConnectionContext connection) + public Task RemoveUser(HubConnectionContext connection) { if (_usersOnline.TryRemove(connection, out var userDetails)) { diff --git a/samples/ChatSample/PresenceHubLifetimeManager.cs b/samples/ChatSample/PresenceHubLifetimeManager.cs index c45a59e1e7..a1d5bfbc22 100644 --- a/samples/ChatSample/PresenceHubLifetimeManager.cs +++ b/samples/ChatSample/PresenceHubLifetimeManager.cs @@ -36,7 +36,7 @@ namespace ChatSample where THubLifetimeManager : HubLifetimeManager where THub : HubWithPresence { - private readonly ConnectionList _connections = new ConnectionList(); + private readonly HubConnectionList _connections = new HubConnectionList(); private readonly IUserTracker _userTracker; private readonly IServiceScopeFactory _serviceScopeFactory; private readonly ILogger _logger; @@ -57,14 +57,14 @@ namespace ChatSample _wrappedHubLifetimeManager = serviceProvider.GetRequiredService(); } - public override async Task OnConnectedAsync(ConnectionContext connection) + public override async Task OnConnectedAsync(HubConnectionContext connection) { await _wrappedHubLifetimeManager.OnConnectedAsync(connection); _connections.Add(connection); await _userTracker.AddUser(connection, new UserDetails(connection.ConnectionId, connection.User.Identity.Name)); } - public override async Task OnDisconnectedAsync(ConnectionContext connection) + public override async Task OnDisconnectedAsync(HubConnectionContext connection) { await _wrappedHubLifetimeManager.OnDisconnectedAsync(connection); _connections.Remove(connection); diff --git a/samples/ChatSample/RedisUserTracker.cs b/samples/ChatSample/RedisUserTracker.cs index 06753884ae..0eee499852 100644 --- a/samples/ChatSample/RedisUserTracker.cs +++ b/samples/ChatSample/RedisUserTracker.cs @@ -9,8 +9,8 @@ using System.Net; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR.Redis; -using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Newtonsoft.Json; @@ -129,7 +129,7 @@ namespace ChatSample } } - public async Task AddUser(ConnectionContext connection, UserDetails userDetails) + public async Task AddUser(HubConnectionContext connection, UserDetails userDetails) { var key = GetUserRedisKey(connection); var user = SerializeUser(connection); @@ -156,7 +156,7 @@ namespace ChatSample } } - public async Task RemoveUser(ConnectionContext connection) + public async Task RemoveUser(HubConnectionContext connection) { await _userSyncSempaphore.WaitAsync(); try @@ -180,7 +180,7 @@ namespace ChatSample } } - private static string GetUserRedisKey(ConnectionContext connection) => $"user:{connection.ConnectionId}"; + private static string GetUserRedisKey(HubConnectionContext connection) => $"user:{connection.ConnectionId}"; private static void Scan(object state) { @@ -319,7 +319,7 @@ namespace ChatSample } } - private static string SerializeUser(ConnectionContext connection) => + private static string SerializeUser(HubConnectionContext connection) => $"{{ \"ConnectionID\": \"{connection.ConnectionId}\", \"Name\": \"{connection.User.Identity.Name}\" }}"; private static UserDetails DeserializerUser(string userJson) => diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index a276fde7e7..085648ba76 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -10,7 +10,6 @@ using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Newtonsoft.Json; @@ -22,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { private const string RedisSubscriptionsMetadataName = "redis_subscriptions"; - private readonly ConnectionList _connections = new ConnectionList(); + private readonly HubConnectionList _connections = new HubConnectionList(); // TODO: Investigate "memory leak" entries never get removed private readonly ConcurrentDictionary _groups = new ConcurrentDictionary(); private readonly ConnectionMultiplexer _redisServerConnection; @@ -128,7 +127,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis await _bus.PublishAsync(channel, payload); } - public override Task OnConnectedAsync(ConnectionContext connection) + public override Task OnConnectedAsync(HubConnectionContext connection) { var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet()); var connectionTask = Task.CompletedTask; @@ -173,7 +172,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis return Task.WhenAll(connectionTask, userTask); } - public override Task OnDisconnectedAsync(ConnectionContext connection) + public override Task OnDisconnectedAsync(HubConnectionContext connection) { _connections.Remove(connection); @@ -307,14 +306,13 @@ namespace Microsoft.AspNetCore.SignalR.Redis _redisServerConnection.Dispose(); } - private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage) + private async Task WriteAsync(HubConnectionContext connection, HubMessage hubMessage) { - var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol); - var data = protocol.WriteToArray(hubMessage); + var data = connection.Protocol.WriteToArray(hubMessage); - while (await connection.Transport.Out.WaitToWriteAsync()) + while (await connection.Output.WaitToWriteAsync()) { - if (connection.Transport.Out.TryWrite(data)) + if (connection.Output.TryWrite(data)) { break; } @@ -363,7 +361,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis private class GroupData { public SemaphoreSlim Lock = new SemaphoreSlim(1, 1); - public ConnectionList Connections = new ConnectionList(); + public HubConnectionList Connections = new HubConnectionList(); } } } diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs index 95958f960c..12ec7f3248 100644 --- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs @@ -3,19 +3,16 @@ using System; using System.Collections.Generic; -using System.IO; -using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.AspNetCore.Sockets; namespace Microsoft.AspNetCore.SignalR { public class DefaultHubLifetimeManager : HubLifetimeManager { private long _nextInvocationId = 0; - private readonly ConnectionList _connections = new ConnectionList(); + private readonly HubConnectionList _connections = new HubConnectionList(); public override Task AddGroupAsync(string connectionId, string groupName) { @@ -62,7 +59,7 @@ namespace Microsoft.AspNetCore.SignalR return InvokeAllWhere(methodName, args, c => true); } - private Task InvokeAllWhere(string methodName, object[] args, Func include) + private Task InvokeAllWhere(string methodName, object[] args, Func include) { var tasks = new List(_connections.Count); var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); @@ -107,26 +104,25 @@ namespace Microsoft.AspNetCore.SignalR }); } - public override Task OnConnectedAsync(ConnectionContext connection) + public override Task OnConnectedAsync(HubConnectionContext connection) { _connections.Add(connection); return Task.CompletedTask; } - public override Task OnDisconnectedAsync(ConnectionContext connection) + public override Task OnDisconnectedAsync(HubConnectionContext connection) { _connections.Remove(connection); return Task.CompletedTask; } - private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage) + private async Task WriteAsync(HubConnectionContext connection, HubMessage hubMessage) { - var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol); - var payload = protocol.WriteToArray(hubMessage); + var payload = connection.Protocol.WriteToArray(hubMessage); - while (await connection.Transport.Out.WaitToWriteAsync()) + while (await connection.Output.WaitToWriteAsync()) { - if (connection.Transport.Out.TryWrite(payload)) + if (connection.Output.TryWrite(payload)) { break; } diff --git a/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs b/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs index f5e7b50cfa..3f75195538 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs @@ -2,18 +2,17 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Security.Claims; -using Microsoft.AspNetCore.Sockets; namespace Microsoft.AspNetCore.SignalR { public class HubCallerContext { - public HubCallerContext(ConnectionContext connection) + public HubCallerContext(HubConnectionContext connection) { Connection = connection; } - public ConnectionContext Connection { get; } + public HubConnectionContext Connection { get; } public ClaimsPrincipal User => Connection.User; diff --git a/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs new file mode 100644 index 0000000000..6bd4c5e101 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs @@ -0,0 +1,36 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Security.Claims; +using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.Sockets; + +namespace Microsoft.AspNetCore.SignalR +{ + public class HubConnectionContext + { + private readonly WritableChannel _output; + private readonly ConnectionContext _connectionContext; + + public HubConnectionContext(WritableChannel output, ConnectionContext connectionContext) + { + _output = output; + _connectionContext = connectionContext; + } + + // Used by the HubEndPoint only + internal ReadableChannel Input => _connectionContext.Transport; + + public virtual string ConnectionId => _connectionContext.ConnectionId; + + public virtual ClaimsPrincipal User => _connectionContext.User; + + public virtual ConnectionMetadata Metadata => _connectionContext.Metadata; + + public virtual IHubProtocol Protocol => _connectionContext.Metadata.Get(HubConnectionMetadataNames.HubProtocol); + + public virtual WritableChannel Output => _output; + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/HubConnectionList.cs b/src/Microsoft.AspNetCore.SignalR/HubConnectionList.cs new file mode 100644 index 0000000000..f57b5ec9ba --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR/HubConnectionList.cs @@ -0,0 +1,52 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.SignalR +{ + public class HubConnectionList : IReadOnlyCollection + { + private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); + + public HubConnectionContext this[string connectionId] + { + get + { + if (_connections.TryGetValue(connectionId, out var connection)) + { + return connection; + } + return null; + } + } + + public int Count => _connections.Count; + + public void Add(HubConnectionContext connection) + { + _connections.TryAdd(connection.ConnectionId, connection); + } + + public void Remove(HubConnectionContext connection) + { + _connections.TryRemove(connection.ConnectionId, out _); + } + + public IEnumerator GetEnumerator() + { + foreach (var item in _connections) + { + yield return item.Value; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index e807efb896..53df3de1e4 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -59,24 +59,54 @@ namespace Microsoft.AspNetCore.SignalR public async Task OnConnectedAsync(ConnectionContext connection) { - await ProcessNegotiate(connection); + var output = Channel.CreateUnbounded(); + var connectionContext = new HubConnectionContext(output, connection); + + await ProcessNegotiate(connectionContext); + + // Hubs support multiple producers so we set up this loop to copy + // data written to the HubConnectionContext's channel to the transport channel + async Task WriteToTransport() + { + while (await output.In.WaitToReadAsync()) + { + while (output.In.TryRead(out var buffer)) + { + while (await connection.Transport.Out.WaitToWriteAsync()) + { + if (connection.Transport.Out.TryWrite(buffer)) + { + break; + } + } + } + } + } + + var writingOutputTask = WriteToTransport(); try { - await _lifetimeManager.OnConnectedAsync(connection); - await RunHubAsync(connection); + await _lifetimeManager.OnConnectedAsync(connectionContext); + await RunHubAsync(connectionContext); } finally { - await _lifetimeManager.OnDisconnectedAsync(connection); + await _lifetimeManager.OnDisconnectedAsync(connectionContext); + + // Nothing should be writing to the HubConnectionContext + output.Out.TryComplete(); + + // This should unwind once we complete the output + await writingOutputTask; } } - private async Task ProcessNegotiate(ConnectionContext connection) + private async Task ProcessNegotiate(HubConnectionContext connection) { - while (await connection.Transport.In.WaitToReadAsync()) + while (await connection.Input.WaitToReadAsync()) { - while (connection.Transport.In.TryRead(out var buffer)) + while (connection.Input.TryRead(out var buffer)) { if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage)) { @@ -92,7 +122,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task RunHubAsync(ConnectionContext connection) + private async Task RunHubAsync(HubConnectionContext connection) { await HubOnConnectedAsync(connection); @@ -110,7 +140,7 @@ namespace Microsoft.AspNetCore.SignalR await HubOnDisconnectedAsync(connection, null); } - private async Task HubOnConnectedAsync(ConnectionContext connection) + private async Task HubOnConnectedAsync(HubConnectionContext connection) { try { @@ -136,7 +166,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task HubOnDisconnectedAsync(ConnectionContext connection, Exception exception) + private async Task HubOnDisconnectedAsync(HubConnectionContext connection, Exception exception) { try { @@ -162,7 +192,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task DispatchMessagesAsync(ConnectionContext connection) + private async Task DispatchMessagesAsync(HubConnectionContext connection) { // We use these for error handling. Since we dispatch multiple hub invocations // in parallel, we need a way to communicate failure back to the main processing loop. The @@ -174,9 +204,9 @@ namespace Microsoft.AspNetCore.SignalR try { - while (await connection.Transport.In.WaitToReadAsync(cts.Token)) + while (await connection.Input.WaitToReadAsync(cts.Token)) { - while (connection.Transport.In.TryRead(out var buffer)) + while (connection.Input.TryRead(out var buffer)) { if (protocol.TryParseMessages(buffer, this, out var hubMessages)) { @@ -212,7 +242,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task ProcessInvocation(ConnectionContext connection, + private async Task ProcessInvocation(HubConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage, CancellationTokenSource dispatcherCancellation, @@ -234,7 +264,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task Execute(ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage) + private async Task Execute(HubConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage) { if (!_methods.TryGetValue(invocationMessage.Target, out var descriptor)) { @@ -248,13 +278,13 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task SendMessageAsync(ConnectionContext connection, IHubProtocol protocol, HubMessage hubMessage) + private async Task SendMessageAsync(HubConnectionContext connection, IHubProtocol protocol, HubMessage hubMessage) { var payload = protocol.WriteToArray(hubMessage); - while (await connection.Transport.Out.WaitToWriteAsync()) + while (await connection.Output.WaitToWriteAsync()) { - if (connection.Transport.Out.TryWrite(payload)) + if (connection.Output.TryWrite(payload)) { return; } @@ -265,7 +295,7 @@ namespace Microsoft.AspNetCore.SignalR throw new OperationCanceledException("Outbound channel was closed while trying to write hub message"); } - private async Task Invoke(HubMethodDescriptor descriptor, ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage) + private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage) { var methodExecutor = descriptor.MethodExecutor; @@ -341,7 +371,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private void InitializeHub(THub hub, ConnectionContext connection) + private void InitializeHub(THub hub, HubConnectionContext connection) { hub.Clients = _hubContext.Clients; hub.Context = new HubCallerContext(connection); @@ -363,7 +393,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task StreamResultsAsync(string invocationId, ConnectionContext connection, IHubProtocol protocol, IAsyncEnumerator enumerator) + private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IHubProtocol protocol, IAsyncEnumerator enumerator) { // TODO: Cancellation? See https://github.com/aspnet/SignalR/issues/481 try diff --git a/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs index 872ea8901c..0879fa3227 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs @@ -2,15 +2,14 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Threading.Tasks; -using Microsoft.AspNetCore.Sockets; namespace Microsoft.AspNetCore.SignalR { public abstract class HubLifetimeManager { - public abstract Task OnConnectedAsync(ConnectionContext connection); + public abstract Task OnConnectedAsync(HubConnectionContext connection); - public abstract Task OnDisconnectedAsync(ConnectionContext connection); + public abstract Task OnDisconnectedAsync(HubConnectionContext connection); public abstract Task InvokeAllAsync(string methodName, object[] args); diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs index 2fdc6202ed..1dec7f4cfc 100644 --- a/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs @@ -10,7 +10,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal { public class DefaultHubProtocolResolver : IHubProtocolResolver { - public IHubProtocol GetProtocol(string protocolName, ConnectionContext connection) + public IHubProtocol GetProtocol(string protocolName, HubConnectionContext connection) { switch (protocolName?.ToLowerInvariant()) { diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs index e797528770..29d1c392d8 100644 --- a/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs @@ -2,12 +2,11 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.AspNetCore.Sockets; namespace Microsoft.AspNetCore.SignalR.Internal { public interface IHubProtocolResolver { - IHubProtocol GetProtocol(string protocolName, ConnectionContext connection); + IHubProtocol GetProtocol(string protocolName, HubConnectionContext connection); } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionList.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs similarity index 100% rename from src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionList.cs rename to src/Microsoft.AspNetCore.Sockets/ConnectionList.cs diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index e9cf7c6cdb..d35f34a061 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -43,7 +43,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var mockLifetimeManager = new Mock>(); mockLifetimeManager - .Setup(m => m.OnConnectedAsync(It.IsAny())) + .Setup(m => m.OnConnectedAsync(It.IsAny())) .Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed.")); var mockHubActivator = new Mock>(); @@ -64,8 +64,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests client.Dispose(); - mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); - mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); // No hubs should be created since the connection is terminated mockHubActivator.Verify(m => m.Create(), Times.Never); mockHubActivator.Verify(m => m.Release(It.IsAny()), Times.Never); @@ -91,8 +91,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests var exception = await Assert.ThrowsAsync(async () => await endPointTask); Assert.Equal("Hub OnConnected failed.", exception.Message); - mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); - mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); } } @@ -115,8 +115,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests var exception = await Assert.ThrowsAsync(async () => await endPointTask); Assert.Equal("Hub OnDisconnected failed.", exception.Message); - mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); - mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs index e84a5a640d..927b87a19e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; @@ -18,7 +19,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests [MemberData(nameof(HubProtocols))] public void DefaultHubProtocolResolverTestsCanCreateSupportedProtocols(IHubProtocol protocol) { - var mockConnection = new Mock(); + var mockConnection = new Mock(Channel.CreateUnbounded().Out, new Mock().Object); Assert.IsType( protocol.GetType(), new DefaultHubProtocolResolver().GetProtocol(protocol.Name, mockConnection.Object)); @@ -29,7 +30,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests [InlineData("dummy")] public void DefaultHubProtocolResolverThrowsForNotSupportedProtocol(string protocolName) { - var mockConnection = new Mock(); + var mockConnection = new Mock(Channel.CreateUnbounded().Out, new Mock().Object); var exception = Assert.Throws( () => new DefaultHubProtocolResolver().GetProtocol(protocolName, mockConnection.Object));