diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs index a2121063e3..46073813f9 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs @@ -4,6 +4,9 @@ using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Encoders; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.Logging.Abstractions; @@ -17,18 +20,35 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [Params(1, 10, 1000)] public int Connections; + [Params("json", "msgpack")] + public string Protocol; + [GlobalSetup] public void GlobalSetup() { _hubLifetimeManager = new DefaultHubLifetimeManager(); - var options = new UnboundedChannelOptions { AllowSynchronousContinuations = true }; + + + IHubProtocol protocol; + + if (Protocol == "json") + { + protocol = new JsonHubProtocol(); + } + else + { + protocol = new MessagePackHubProtocol(); + } + + var encoder = new PassThroughEncoder(); for (var i = 0; i < Connections; ++i) { var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport); - - _hubLifetimeManager.OnConnectedAsync(new HubConnectionContext(connection, Timeout.InfiniteTimeSpan, NullLoggerFactory.Instance)).Wait(); + var hubConnection = new HubConnectionContext(connection, Timeout.InfiniteTimeSpan, NullLoggerFactory.Instance); + hubConnection.ProtocolReaderWriter = new HubProtocolReaderWriter(protocol, encoder); + _hubLifetimeManager.OnConnectedAsync(hubConnection).Wait(); } _hubContext = new HubContext(_hubLifetimeManager); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 6cfb3bcc58..8f969273ee 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -2,15 +2,16 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.IO.Pipelines; using System.Net; +using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; using System.Security.Claims; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Protocols; @@ -31,30 +32,26 @@ namespace Microsoft.AspNetCore.SignalR private static readonly PassThroughEncoder PassThroughEncoder = new PassThroughEncoder(); private readonly ConnectionContext _connectionContext; - private readonly Channel _output; private readonly ILogger _logger; private readonly CancellationTokenSource _connectionAbortedTokenSource = new CancellationTokenSource(); private readonly TaskCompletionSource _abortCompletedTcs = new TaskCompletionSource(); private readonly long _keepAliveDuration; - private Task _writingTask = Task.CompletedTask; + private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1); + private long _lastSendTimestamp = Stopwatch.GetTimestamp(); - public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory): - this(connectionContext, keepAliveInterval, loggerFactory, Channel.CreateUnbounded()) + public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory) { - } - - internal HubConnectionContext(ConnectionContext connectionContext, - TimeSpan keepAliveInterval, - ILoggerFactory loggerFactory, - Channel output) - { - _output = output; _connectionContext = connectionContext; _logger = loggerFactory.CreateLogger(); ConnectionAbortedToken = _connectionAbortedTokenSource.Token; _keepAliveDuration = (int)keepAliveInterval.TotalMilliseconds * (Stopwatch.Frequency / 1000); + + if (Features.Get() == null) + { + Features.Get()?.OnHeartbeat(state => ((HubConnectionContext)state).KeepAliveTick(), this); + } } public virtual CancellationToken ConnectionAbortedToken { get; } @@ -86,33 +83,21 @@ namespace Microsoft.AspNetCore.SignalR public int? LocalPort => Features.Get()?.LocalPort; - public async Task WriteAsync(HubMessage message, bool throwOnFailure = false) + public virtual async Task WriteAsync(HubMessage message) { - while (await _output.Writer.WaitToWriteAsync()) - { - if (_output.Writer.TryWrite(message)) - { - return; - } - } + await _writeLock.WaitAsync(); - _logger.OutboundChannelClosed(); + var buffer = ProtocolReaderWriter.WriteMessage(message); - if (throwOnFailure) - { - throw new OperationCanceledException("Outbound channel was closed while trying to write hub message"); - } + _connectionContext.Transport.Output.Write(buffer); + + Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp()); + + await _connectionContext.Transport.Output.FlushAsync(CancellationToken.None); + + _writeLock.Release(); } - - public async Task DisposeAsync() - { - // Nothing should be writing to the HubConnectionContext - _output.Writer.TryComplete(); - - // This should unwind once we complete the output - await _writingTask; - } - + public virtual void Abort() { // If we already triggered the token then noop, this isn't thread safe but it's good enough @@ -126,13 +111,6 @@ namespace Microsoft.AspNetCore.SignalR Task.Factory.StartNew(_abortedCallback, this); } - // Hubs support multiple producers so we set up this loop to copy - // data written to the HubConnectionContext's channel to the transport channel - internal Task StartAsync() - { - return _writingTask = StartAsyncCore(); - } - internal async Task NegotiateAsync(TimeSpan timeout, IHubProtocolResolver protocolResolver, IUserIdProvider userIdProvider) { try @@ -213,35 +191,6 @@ namespace Microsoft.AspNetCore.SignalR return _abortCompletedTcs.Task; } - private async Task StartAsyncCore() - { - Debug.Assert(ProtocolReaderWriter != null, "Expected the ProtocolReaderWriter to be set before StartAsync is called"); - - if (Features.Get() == null) - { - Features.Get()?.OnHeartbeat(state => ((HubConnectionContext)state).KeepAliveTick(), this); - } - - try - { - while (await _output.Reader.WaitToReadAsync()) - { - while (_output.Reader.TryRead(out var hubMessage)) - { - var buffer = ProtocolReaderWriter.WriteMessage(hubMessage); - - await _connectionContext.Transport.Output.WriteAsync(buffer); - - Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp()); - } - } - } - catch (Exception ex) - { - Abort(ex); - } - } - private void KeepAliveTick() { // Implements the keep-alive tick behavior @@ -257,15 +206,9 @@ namespace Microsoft.AspNetCore.SignalR // adding a Ping message when the transport is full is unnecessary since the // transport is still in the process of sending frames. - if (_output.Writer.TryWrite(PingMessage.Instance)) - { - _logger.SentPing(); - } - else - { - // This isn't necessarily an error, it just indicates that the transport is applying backpressure right now. - _logger.TransportBufferFull(); - } + _logger.SentPing(); + + _ = WriteAsync(PingMessage.Instance); Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp()); } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index 535c52e3f5..16f495981f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -64,9 +64,6 @@ namespace Microsoft.AspNetCore.SignalR return; } - // We don't need to hold this task, it's also held internally and awaited by DisposeAsync. - _ = connectionContext.StartAsync(); - try { await _lifetimeManager.OnConnectedAsync(connectionContext); @@ -75,8 +72,6 @@ namespace Microsoft.AspNetCore.SignalR finally { await _lifetimeManager.OnDisconnectedAsync(connectionContext); - - await connectionContext.DisposeAsync(); } } @@ -277,7 +272,7 @@ namespace Microsoft.AspNetCore.SignalR private Task SendMessageAsync(HubConnectionContext connection, HubMessage hubMessage) { - return connection.WriteAsync(hubMessage, throwOnFailure: true); + return connection.WriteAsync(hubMessage); } private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Properties/AssemblyInfo.cs b/src/Microsoft.AspNetCore.SignalR.Core/Properties/AssemblyInfo.cs index 3661bf333e..b3c401163f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Properties/AssemblyInfo.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Properties/AssemblyInfo.cs @@ -4,3 +4,4 @@ using System.Runtime.CompilerServices; [assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Tests.Utils, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Microbenchmarks, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 45ce0998c9..3d52960770 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -402,7 +402,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis var invocation = message.CreateInvocation(); foreach (var connection in _connections) { - tasks.Add(connection.WriteAsync(invocation)); + try + { + tasks.Add(connection.WriteAsync(invocation)); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } } await Task.WhenAll(tasks); @@ -434,7 +441,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis { if (!excludedIds.Contains(connection.ConnectionId)) { - tasks.Add(connection.WriteAsync(invocation)); + try + { + tasks.Add(connection.WriteAsync(invocation)); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } } } @@ -562,7 +576,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis continue; } - tasks.Add(groupConnection.WriteAsync(invocation)); + try + { + tasks.Add(groupConnection.WriteAsync(invocation)); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } } await Task.WhenAll(tasks); diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs index d1f4eea6e1..7f42a38095 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs @@ -3,16 +3,11 @@ using System; using System.Collections.Generic; -using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; -using Microsoft.AspNetCore.SignalR.Internal; -using Microsoft.AspNetCore.SignalR.Internal.Encoders; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests; -using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; using Moq; using Xunit; @@ -68,8 +63,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await AssertMessageAsync(client1); - await connection1.DisposeAsync().OrTimeout(); - await connection2.DisposeAsync().OrTimeout(); Assert.Null(client2.TryRead()); } } @@ -97,9 +90,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await AssertMessageAsync(client1); Assert.Null(client2.TryRead()); - - await connection1.DisposeAsync().OrTimeout(); - await connection2.DisposeAsync().OrTimeout(); } } @@ -123,14 +113,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager.AddGroupAsync(connection1.ConnectionId, "gunit").OrTimeout(); await manager.AddGroupAsync(connection2.ConnectionId, "gunit").OrTimeout(); - var excludedIds = new List{ client2.Connection.ConnectionId }; + var excludedIds = new List { client2.Connection.ConnectionId }; await manager.SendGroupExceptAsync("gunit", "Hello", new object[] { "World" }, excludedIds).OrTimeout(); await AssertMessageAsync(client1); Assert.Null(client2.TryRead()); - - await connection1.DisposeAsync().OrTimeout(); - await connection2.DisposeAsync().OrTimeout(); } } @@ -150,8 +137,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - await AssertMessageAsync(client); } } @@ -226,8 +211,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await AssertMessageAsync(client1); - await connection1.DisposeAsync().OrTimeout(); - await connection2.DisposeAsync().OrTimeout(); Assert.Null(client2.TryRead()); } } @@ -307,7 +290,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager.SendGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); Assert.Null(client.TryRead()); } } @@ -377,8 +359,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager2.SendGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - await AssertMessageAsync(client); } } @@ -402,10 +382,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager.SendGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); await AssertMessageAsync(client); - await connection.DisposeAsync().OrTimeout(); Assert.Null(client.TryRead()); } } @@ -433,10 +411,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager2.SendGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - await AssertMessageAsync(client); - await connection.DisposeAsync().OrTimeout(); Assert.Null(client.TryRead()); } } @@ -469,7 +444,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager2.SendGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); Assert.Null(client.TryRead()); } } @@ -496,8 +470,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager1.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - await AssertMessageAsync(client); Assert.Null(client.TryRead()); } @@ -518,10 +490,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests using (var client = new TestClient()) { // Force an exception when writing to connection - var writer = new Mock>(); - writer.Setup(o => o.WaitToWriteAsync(It.IsAny())).Throws(new Exception()); - - var connection = HubConnectionContextUtils.Create(client.Connection, new MockChannel(writer.Object)); + var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection); + connectionMock.Setup(m => m.WriteAsync(It.IsAny())).Throws(new Exception()); + var connection = connectionMock.Object; await manager2.OnConnectedAsync(connection).OrTimeout(); @@ -542,10 +513,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests using (var client = new TestClient()) { // Force an exception when writing to connection - var writer = new Mock>(); - writer.Setup(o => o.WaitToWriteAsync(It.IsAny())).Throws(new Exception("Message")); - - var connection = HubConnectionContextUtils.Create(client.Connection, new MockChannel(writer.Object)); + var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection); + connectionMock.Setup(m => m.WriteAsync(It.IsAny())).Throws(new Exception("Message")); + var connection = connectionMock.Object; await manager.OnConnectedAsync(connection).OrTimeout(); @@ -566,10 +536,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests using (var client2 = new TestClient()) { // Force an exception when writing to connection - var writer = new Mock>(); - writer.Setup(o => o.WaitToWriteAsync(It.IsAny())).Throws(new Exception()); + var connectionMock = HubConnectionContextUtils.CreateMock(client1.Connection); + connectionMock.Setup(m => m.WriteAsync(It.IsAny())).Throws(new Exception()); - var connection1 = HubConnectionContextUtils.Create(client1.Connection, new MockChannel(writer.Object)); + var connection1 = connectionMock.Object; var connection2 = HubConnectionContextUtils.Create(client2.Connection); await manager.OnConnectedAsync(connection1).OrTimeout(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs index e90a65510d..25f6327627 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs @@ -2,34 +2,32 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Threading.Channels; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Encoders; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.Logging.Abstractions; +using Moq; namespace Microsoft.AspNetCore.SignalR.Tests { public static class HubConnectionContextUtils { - public static HubConnectionContext Create(DefaultConnectionContext connection, Channel replacementOutput = null) + public static HubConnectionContext Create(DefaultConnectionContext connection) { - HubConnectionContext context = null; - if (replacementOutput != null) + return new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance) { - context = new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance, replacementOutput); - } - else - { - context = new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance); - } + ProtocolReaderWriter = new HubProtocolReaderWriter(new JsonHubProtocol(), new PassThroughEncoder()) + }; + } - context.ProtocolReaderWriter = new HubProtocolReaderWriter(new JsonHubProtocol(), new PassThroughEncoder()); + public static Mock CreateMock(DefaultConnectionContext connection) + { + var mock = new Mock(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance) { CallBase = true }; + var readerWriter = new HubProtocolReaderWriter(new JsonHubProtocol(), new PassThroughEncoder()); + mock.SetupGet(m => m.ProtocolReaderWriter).Returns(readerWriter); + return mock; - _ = context.StartAsync(); - - return context; } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs index 226104230e..65a333e142 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs @@ -17,11 +17,13 @@ namespace System.Threading.Tasks public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) { - var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout)); + var cts = new CancellationTokenSource(); + var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout, cts.Token)); if (completed != task) { throw new TimeoutException(GetMessage(memberName, filePath, lineNumber)); } + cts.Cancel(); await task; } @@ -33,11 +35,13 @@ namespace System.Threading.Tasks public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) { - var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout)); + var cts = new CancellationTokenSource(); + var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout, cts.Token)); if (completed != task) { throw new TimeoutException(GetMessage(memberName, filePath, lineNumber)); } + cts.Cancel(); return await task; } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs index da0ddcc4c1..f298fc24eb 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs @@ -25,9 +25,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests await manager.SendAllAsync("Hello", new object[] { "World" }).OrTimeout(); - await connection1.DisposeAsync().OrTimeout(); - await connection2.DisposeAsync().OrTimeout(); - var message = Assert.IsType(client1.TryRead()); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); @@ -57,9 +54,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests await manager.SendAllAsync("Hello", new object[] { "World" }).OrTimeout(); - await connection1.DisposeAsync().OrTimeout(); - await connection2.DisposeAsync().OrTimeout(); - var message = Assert.IsType(client1.TryRead()); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); @@ -86,9 +80,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests await manager.SendGroupAsync("gunit", "Hello", new object[] { "World" }).OrTimeout(); - await connection1.DisposeAsync().OrTimeout(); - await connection2.DisposeAsync().OrTimeout(); - var message = Assert.IsType(client1.TryRead()); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); @@ -110,8 +101,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests await manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - var message = Assert.IsType(client.TryRead()); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); @@ -125,11 +114,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var client = new TestClient()) { // Force an exception when writing to connection - var writer = new Mock>(); - writer.Setup(o => o.WaitToWriteAsync(It.IsAny())).Throws(new Exception("Message")); - var manager = new DefaultHubLifetimeManager(); - var connection = HubConnectionContextUtils.Create(client.Connection, new MockChannel(writer.Object)); + + var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection); + connectionMock.Setup(m => m.WriteAsync(It.IsAny())).Throws(new Exception("Message")); + var connection = connectionMock.Object; await manager.OnConnectedAsync(connection).OrTimeout(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 2e84b84ca7..21e6ff20c3 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -1645,32 +1645,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } - [Fact] - public async Task ConnectionClosedIfWritingToTransportFails() - { - // MessagePack does not support serializing objects or private types (including anonymous types) - // and throws. In this test we make sure that this exception closes the connection and bubbles up. - - var serviceProvider = HubEndPointTestUtils.CreateServiceProvider(); - - var endPoint = serviceProvider.GetService>(); - - using (var client = new TestClient(false, new MessagePackHubProtocol())) - { - var transportFeature = new Mock(); - transportFeature.SetupGet(f => f.TransportCapabilities).Returns(TransferMode.Binary); - client.Connection.Features.Set(transportFeature.Object); - - var endPointLifetime = endPoint.OnConnectedAsync(client.Connection); - - await client.Connected.OrThrowIfOtherFails(endPointLifetime).OrTimeout(); - - await client.SendInvocationAsync(nameof(MethodHub.SendAnonymousObject)).OrTimeout(); - - await Assert.ThrowsAsync(() => endPointLifetime.OrTimeout()); - } - } - [Fact] public async Task AcceptsPingMessages() { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs index 61910660a3..ed825d32fa 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs @@ -3,10 +3,10 @@ using System; using System.Collections.Generic; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Protocols; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; using Moq; @@ -26,7 +26,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests [MemberData(nameof(HubProtocols))] public void DefaultHubProtocolResolverTestsCanCreateSupportedProtocols(IHubProtocol protocol) { - var mockConnection = new Mock(new Mock().Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance); + + var connection = new Mock(); + connection.Setup(m => m.Features).Returns(new FeatureCollection()); + var mockConnection = new Mock(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true }; var resolver = new DefaultHubProtocolResolver(Options.Create(new HubOptions()), AllProtocols, NullLogger.Instance); Assert.IsType( protocol.GetType(), @@ -36,7 +39,9 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests [Fact] public void DefaultHubProtocolResolverThrowsForNullProtocol() { - var mockConnection = new Mock(new Mock().Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance); + var connection = new Mock(); + connection.Setup(m => m.Features).Returns(new FeatureCollection()); + var mockConnection = new Mock(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true }; var resolver = new DefaultHubProtocolResolver(Options.Create(new HubOptions()), AllProtocols, NullLogger.Instance); var exception = Assert.Throws( () => resolver.GetProtocol(null, mockConnection.Object)); @@ -47,7 +52,9 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests [Fact] public void DefaultHubProtocolResolverThrowsForNotSupportedProtocol() { - var mockConnection = new Mock(new Mock().Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance); + var connection = new Mock(); + connection.Setup(m => m.Features).Returns(new FeatureCollection()); + var mockConnection = new Mock(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true }; var resolver = new DefaultHubProtocolResolver(Options.Create(new HubOptions()), AllProtocols, NullLogger.Instance); var exception = Assert.Throws( () => resolver.GetProtocol("notARealProtocol", mockConnection.Object)); @@ -58,7 +65,9 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests [Fact] public void RegisteringMultipleHubProtocolsFails() { - var mockConnection = new Mock(new Mock().Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance); + var connection = new Mock(); + connection.Setup(m => m.Features).Returns(new FeatureCollection()); + var mockConnection = new Mock(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true }; var exception = Assert.Throws(() => new DefaultHubProtocolResolver(Options.Create(new HubOptions()), new[] { new JsonHubProtocol(), new JsonHubProtocol()