From 079a56be1a2e9d5e293a685d53a2a1d106ef29e8 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Fri, 16 Mar 2018 16:08:11 -0700 Subject: [PATCH 1/3] Small optimizations (#1617) - Return ValueTask instead of Task from WriteAsync helpers - Use TryGet instead of foreach to avoid enumerator (though it's just a stack allocation here) --- src/Common/PipeWriterStream.cs | 2 +- src/Common/StreamExtensions.cs | 23 ++++++++++-- src/Common/WebSocketExtensions.cs | 36 ++++++++++++++----- .../Internal/StreamExtensions.cs | 17 ++------- .../SendUtils.cs | 2 +- 5 files changed, 52 insertions(+), 28 deletions(-) diff --git a/src/Common/PipeWriterStream.cs b/src/Common/PipeWriterStream.cs index 56472af59e..df6f85aaa1 100644 --- a/src/Common/PipeWriterStream.cs +++ b/src/Common/PipeWriterStream.cs @@ -67,7 +67,7 @@ namespace System.IO.Pipelines { _pipeWriter.Write(source.Span); _length += source.Length; - return new ValueTask(Task.CompletedTask); + return default; } #endif } diff --git a/src/Common/StreamExtensions.cs b/src/Common/StreamExtensions.cs index ad801470c2..60da892475 100644 --- a/src/Common/StreamExtensions.cs +++ b/src/Common/StreamExtensions.cs @@ -11,10 +11,27 @@ namespace System.IO { internal static class StreamExtensions { - public static async Task WriteAsync(this Stream stream, ReadOnlySequence buffer, CancellationToken cancellationToken = default) + public static ValueTask WriteAsync(this Stream stream, ReadOnlySequence buffer, CancellationToken cancellationToken = default) { - // REVIEW: Should we special case IsSingleSegment here? - foreach (var segment in buffer) + if (buffer.IsSingleSegment) + { +#if NETCOREAPP2_1 + return stream.WriteAsync(buffer.First, cancellationToken); +#else + var isArray = MemoryMarshal.TryGetArray(buffer.First, out var arraySegment); + // We're using the managed memory pool which is backed by managed buffers + Debug.Assert(isArray); + return new ValueTask(stream.WriteAsync(arraySegment.Array, arraySegment.Offset, arraySegment.Count, cancellationToken)); +#endif + } + + return WriteMultiSegmentAsync(stream, buffer, cancellationToken); + } + + private static async ValueTask WriteMultiSegmentAsync(Stream stream, ReadOnlySequence buffer, CancellationToken cancellationToken) + { + var position = buffer.Start; + while (buffer.TryGet(ref position, out var segment)) { #if NETCOREAPP2_1 await stream.WriteAsync(segment, cancellationToken); diff --git a/src/Common/WebSocketExtensions.cs b/src/Common/WebSocketExtensions.cs index 06094a8df8..8e3d4feb50 100644 --- a/src/Common/WebSocketExtensions.cs +++ b/src/Common/WebSocketExtensions.cs @@ -1,12 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; using System.Buffers; -using System.Collections.Generic; using System.Diagnostics; using System.Runtime.InteropServices; -using System.Text; using System.Threading; using System.Threading.Tasks; @@ -14,29 +11,50 @@ namespace System.Net.WebSockets { internal static class WebSocketExtensions { - public static Task SendAsync(this WebSocket webSocket, ReadOnlySequence buffer, WebSocketMessageType webSocketMessageType, CancellationToken cancellationToken = default) + public static ValueTask SendAsync(this WebSocket webSocket, ReadOnlySequence buffer, WebSocketMessageType webSocketMessageType, CancellationToken cancellationToken = default) { - // TODO: Consider chunking writes here if we get a multi segment buffer #if NETCOREAPP2_1 if (buffer.IsSingleSegment) { - return webSocket.SendAsync(buffer.First, webSocketMessageType, endOfMessage: true, cancellationToken).AsTask(); + return webSocket.SendAsync(buffer.First, webSocketMessageType, endOfMessage: true, cancellationToken); } else { - return webSocket.SendAsync(buffer.ToArray(), webSocketMessageType, endOfMessage: true, cancellationToken); + return SendMultiSegmentAsync(webSocket, buffer, webSocketMessageType, cancellationToken); } #else if (buffer.IsSingleSegment) { var isArray = MemoryMarshal.TryGetArray(buffer.First, out var segment); Debug.Assert(isArray); - return webSocket.SendAsync(segment, webSocketMessageType, endOfMessage: true, cancellationToken); + return new ValueTask(webSocket.SendAsync(segment, webSocketMessageType, endOfMessage: true, cancellationToken)); } else { - return webSocket.SendAsync(new ArraySegment(buffer.ToArray()), webSocketMessageType, true, cancellationToken); + return SendMultiSegmentAsync(webSocket, buffer, webSocketMessageType, cancellationToken); } +#endif + } + + private static async ValueTask SendMultiSegmentAsync(WebSocket webSocket, ReadOnlySequence buffer, WebSocketMessageType webSocketMessageType, CancellationToken cancellationToken = default) + { + var position = buffer.Start; + while (buffer.TryGet(ref position, out var segment)) + { +#if NETCOREAPP2_1 + await webSocket.SendAsync(segment, webSocketMessageType, endOfMessage: false, cancellationToken); +#else + var isArray = MemoryMarshal.TryGetArray(segment, out var arraySegment); + Debug.Assert(isArray); + await webSocket.SendAsync(arraySegment, webSocketMessageType, endOfMessage: false, cancellationToken); +#endif + } + + // Empty end of message frame +#if NETCOREAPP2_1 + await webSocket.SendAsync(Memory.Empty, webSocketMessageType, endOfMessage: true, cancellationToken); +#else + await webSocket.SendAsync(new ArraySegment(Array.Empty()), webSocketMessageType, endOfMessage: true, cancellationToken); #endif } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/StreamExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/StreamExtensions.cs index 8adf707c49..6b3df653d9 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/StreamExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/StreamExtensions.cs @@ -12,7 +12,9 @@ namespace System.IO.Pipelines { try { - await stream.CopyToAsync(writer, cancellationToken); + // REVIEW: Should we use the default buffer size here? + // 81920 is the default bufferSize, there is no stream.CopyToAsync overload that takes only a cancellationToken + await stream.CopyToAsync(new PipelineWriterStream(writer), bufferSize: 81920, cancellationToken: cancellationToken); } catch (Exception ex) { @@ -22,19 +24,6 @@ namespace System.IO.Pipelines writer.Complete(); } - /// - /// Copies the content of a into a . - /// - /// - /// - /// - /// - private static Task CopyToAsync(this Stream stream, PipeWriter writer, CancellationToken cancellationToken = default) - { - // 81920 is the default bufferSize, there is not stream.CopyToAsync overload that takes only a cancellationToken - return stream.CopyToAsync(new PipelineWriterStream(writer), bufferSize: 81920, cancellationToken: cancellationToken); - } - private class PipelineWriterStream : Stream { private readonly PipeWriter _writer; diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs index d113feec6e..b8b8e31060 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs @@ -111,7 +111,7 @@ namespace Microsoft.AspNetCore.Sockets.Client protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) { - return stream.WriteAsync(_buffer); + return stream.WriteAsync(_buffer).AsTask(); } protected override bool TryComputeLength(out long length) From 0e38ee3e63f4fb7ed08b3626fa85ec5f33952e09 Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Fri, 16 Mar 2018 16:16:34 -0700 Subject: [PATCH 2/3] Create connectionIds using RNGCrypto (#1606) --- .../DefaultHubDispatcherBenchmark.cs | 5 ++- build/dependencies.props | 1 + .../DefaultConnectionContext.cs | 32 +++++++++++++---- .../ConnectionManager.cs | 30 ++++++++++++---- .../HttpConnectionDispatcher.cs | 35 ++++++++++++------- .../Microsoft.AspNetCore.Sockets.Http.csproj | 3 +- .../ConnectionManagerTests.cs | 23 ++++++------ .../HttpConnectionDispatcherTests.cs | 3 ++ .../LongPollingTests.cs | 1 + .../ServerSentEventsTests.cs | 2 -- .../WebSocketsTests.cs | 1 + 11 files changed, 94 insertions(+), 42 deletions(-) diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs index 5bc73b8aba..5799673af7 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -39,9 +39,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks new HubContext(new DefaultHubLifetimeManager()), new Logger>(NullLoggerFactory.Instance)); - var options = new PipeOptions(); - var pair = DuplexPipe.CreateConnectionPair(options, options); - var connection = new Sockets.DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Transport, pair.Application); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new Sockets.DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport); _connectionContext = new NoErrorHubConnectionContext(connection, TimeSpan.Zero, NullLoggerFactory.Instance); diff --git a/build/dependencies.props b/build/dependencies.props index 2af21b499f..84baeb32c6 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -51,6 +51,7 @@ 2.1.0-preview2-30355 2.1.0-preview2-30355 2.1.0-preview2-30355 + 2.1.0-preview2-30355 2.1.0-preview2-30355 2.0.0 2.1.0-preview2-26314-02 diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs index f208da2eae..693fb74312 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs @@ -22,16 +22,20 @@ namespace Microsoft.AspNetCore.Sockets IConnectionHeartbeatFeature, ITransferFormatFeature { - private List<(Action handler, object state)> _heartbeatHandlers = new List<(Action handler, object state)>(); + private object _heartbeatLock = new object(); + private List<(Action handler, object state)> _heartbeatHandlers; // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously // on the same task private TaskCompletionSource _disposeTcs = new TaskCompletionSource(); - public DefaultConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application) + /// + /// Creates the DefaultConnectionContext without Pipes to avoid upfront allocations. + /// The caller is expected to set the and pipes manually. + /// + /// + public DefaultConnectionContext(string id) { - Transport = transport; - Application = application; ConnectionId = id; LastSeenUtc = DateTime.UtcNow; @@ -50,6 +54,13 @@ namespace Microsoft.AspNetCore.Sockets Features.Set(this); } + public DefaultConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application) + : this(id) + { + Transport = transport; + Application = application; + } + public CancellationTokenSource Cancellation { get; set; } public SemaphoreSlim Lock { get; } = new SemaphoreSlim(1, 1); @@ -80,16 +91,25 @@ namespace Microsoft.AspNetCore.Sockets public void OnHeartbeat(Action action, object state) { - lock (_heartbeatHandlers) + lock (_heartbeatLock) { + if (_heartbeatHandlers == null) + { + _heartbeatHandlers = new List<(Action handler, object state)>(); + } _heartbeatHandlers.Add((action, state)); } } public void TickHeartbeat() { - lock (_heartbeatHandlers) + lock (_heartbeatLock) { + if (_heartbeatHandlers == null) + { + return; + } + foreach (var (handler, state) in _heartbeatHandlers) { handler(state); diff --git a/src/Microsoft.AspNetCore.Sockets.Http/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets.Http/ConnectionManager.cs index b04d8e674e..ada2f4ba64 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/ConnectionManager.cs @@ -2,12 +2,14 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers.Text; using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.IO.Pipelines; using System.Net.WebSockets; +using System.Security.Cryptography; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting; @@ -22,6 +24,8 @@ namespace Microsoft.AspNetCore.Sockets // TODO: Consider making this configurable? At least for testing? private static readonly TimeSpan _heartbeatTickRate = TimeSpan.FromSeconds(1); + private static readonly RNGCryptoServiceProvider _keyGenerator = new RNGCryptoServiceProvider(); + private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); private Timer _timer; private readonly ILogger _logger; @@ -63,23 +67,31 @@ namespace Microsoft.AspNetCore.Sockets return false; } - public DefaultConnectionContext CreateConnection(PipeOptions transportPipeOptions, PipeOptions appPipeOptions) + /// + /// Creates a connection without Pipes setup to allow saving allocations until Pipes are needed. + /// + /// + public DefaultConnectionContext CreateConnection() { var id = MakeNewConnectionId(); _logger.CreatedNewConnection(id); var connectionTimer = SocketEventSource.Log.ConnectionStart(id); - var pair = DuplexPipe.CreateConnectionPair(transportPipeOptions, appPipeOptions); - var connection = new DefaultConnectionContext(id, pair.Application, pair.Transport); + var connection = new DefaultConnectionContext(id); _connections.TryAdd(id, (connection, connectionTimer)); return connection; } - public DefaultConnectionContext CreateConnection() + public DefaultConnectionContext CreateConnection(PipeOptions transportPipeOptions, PipeOptions appPipeOptions) { - return CreateConnection(PipeOptions.Default, PipeOptions.Default); + var connection = CreateConnection(); + var pair = DuplexPipe.CreateConnectionPair(transportPipeOptions, appPipeOptions); + connection.Application = pair.Transport; + connection.Transport = pair.Application; + + return connection; } public void RemoveConnection(string id) @@ -94,8 +106,12 @@ namespace Microsoft.AspNetCore.Sockets private static string MakeNewConnectionId() { - // TODO: We need to sign and encyrpt this - return Guid.NewGuid().ToString(); + // TODO: Use Span when WebEncoders implements Span methods https://github.com/aspnet/Home/issues/2966 + // 128 bit buffer / 8 bits per byte = 16 bytes + var buffer = new byte[16]; + _keyGenerator.GetBytes(buffer); + // Generate the id with RNGCrypto because we want a cryptographically random id, which GUID is not + return WebEncoders.Base64UrlEncode(buffer); } private static void Scan(object state) diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index fe621b639f..d70b7127e9 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -49,7 +49,7 @@ namespace Microsoft.AspNetCore.Sockets if (HttpMethods.IsPost(context.Request.Method)) { // POST /{path} - await ProcessSend(context); + await ProcessSend(context, options); } else if (HttpMethods.IsGet(context.Request.Method)) { @@ -99,7 +99,7 @@ namespace Microsoft.AspNetCore.Sockets if (headers.Accept?.Contains(new Net.Http.Headers.MediaTypeHeaderValue("text/event-stream")) == true) { // Connection must already exist - var connection = await GetConnectionAsync(context); + var connection = await GetConnectionAsync(context, options); if (connection == null) { // No such connection, GetConnection already set the response status code @@ -149,7 +149,7 @@ namespace Microsoft.AspNetCore.Sockets // GET /{path} maps to long polling // Connection must already exist - var connection = await GetConnectionAsync(context); + var connection = await GetConnectionAsync(context, options); if (connection == null) { // No such connection, GetConnection already set the response status code @@ -361,7 +361,7 @@ namespace Microsoft.AspNetCore.Sockets context.Response.ContentType = "application/json"; // Establish the connection - var connection = CreateConnectionInternal(options); + var connection = _manager.CreateConnection(); // Set the Connection ID on the logging scope so that logs from now on will have the // Connection ID metadata set. @@ -429,9 +429,9 @@ namespace Microsoft.AspNetCore.Sockets private static string GetConnectionId(HttpContext context) => context.Request.Query["id"]; - private async Task ProcessSend(HttpContext context) + private async Task ProcessSend(HttpContext context, HttpSocketOptions options) { - var connection = await GetConnectionAsync(context); + var connection = await GetConnectionAsync(context, options); if (connection == null) { // No such connection, GetConnection already set the response status code @@ -505,7 +505,7 @@ namespace Microsoft.AspNetCore.Sockets return true; } - private async Task GetConnectionAsync(HttpContext context) + private async Task GetConnectionAsync(HttpContext context, HttpSocketOptions options) { var connectionId = GetConnectionId(context); @@ -527,16 +527,25 @@ namespace Microsoft.AspNetCore.Sockets return null; } + EnsureConnectionStateInternal(connection, options); + return connection; } - private DefaultConnectionContext CreateConnectionInternal(HttpSocketOptions options) + private void EnsureConnectionStateInternal(DefaultConnectionContext connection, HttpSocketOptions options) { - var transportPipeOptions = new PipeOptions(pauseWriterThreshold: options.TransportMaxBufferSize, resumeWriterThreshold: options.TransportMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); - var appPipeOptions = new PipeOptions(pauseWriterThreshold: options.ApplicationMaxBufferSize, resumeWriterThreshold: options.ApplicationMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); - return _manager.CreateConnection(transportPipeOptions, appPipeOptions); + // If the connection doesn't have a pipe yet then create one, we lazily create the pipe to save on allocations until the client actually connects + if (connection.Transport == null) + { + var transportPipeOptions = new PipeOptions(pauseWriterThreshold: options.TransportMaxBufferSize, resumeWriterThreshold: options.TransportMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); + var appPipeOptions = new PipeOptions(pauseWriterThreshold: options.ApplicationMaxBufferSize, resumeWriterThreshold: options.ApplicationMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(transportPipeOptions, appPipeOptions); + connection.Transport = pair.Application; + connection.Application = pair.Transport; + } } + // This is only used for WebSockets connections, which can connect directly without negotiating private async Task GetOrCreateConnectionAsync(HttpContext context, HttpSocketOptions options) { var connectionId = GetConnectionId(context); @@ -545,7 +554,7 @@ namespace Microsoft.AspNetCore.Sockets // There's no connection id so this is a brand new connection if (StringValues.IsNullOrEmpty(connectionId)) { - connection = CreateConnectionInternal(options); + connection = _manager.CreateConnection(); } else if (!_manager.TryGetConnection(connectionId, out connection)) { @@ -555,6 +564,8 @@ namespace Microsoft.AspNetCore.Sockets return null; } + EnsureConnectionStateInternal(connection, options); + return connection; } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj b/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj index cf115bd037..6a97d4da6e 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj @@ -10,7 +10,7 @@ - + @@ -22,6 +22,7 @@ + diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs index 742e4c64ac..e0ff217dfe 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO.Pipelines; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Logging; @@ -22,8 +23,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.Null(connection.ApplicationTask); Assert.Null(connection.TransportTask); Assert.Null(connection.Cancellation); - Assert.NotEqual(default(DateTime), connection.LastSeenUtc); - Assert.NotNull(connection.Transport); + Assert.NotEqual(default, connection.LastSeenUtc); + Assert.Null(connection.Transport); } [Fact] @@ -42,7 +43,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public void AddNewConnection() { var connectionManager = CreateConnectionManager(); - var connection = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default); var transport = connection.Transport; @@ -58,7 +59,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public void RemoveConnection() { var connectionManager = CreateConnectionManager(); - var connection = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default); var transport = connection.Transport; @@ -77,7 +78,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task CloseConnectionsEndsAllPendingConnections() { var connectionManager = CreateConnectionManager(); - var connection = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default); connection.ApplicationTask = Task.Run(async () => { @@ -89,7 +90,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests } finally { - connection.Transport.Input.AdvanceTo(result.Buffer.End); + connection.Transport.Input.AdvanceTo(result.Buffer.End); } }); @@ -115,7 +116,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task DisposingConnectionMultipleTimesWaitsOnConnectionClose() { var connectionManager = CreateConnectionManager(); - var connection = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); connection.ApplicationTask = tcs.Task; @@ -135,7 +136,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task DisposingConnectionMultipleGetsExceptionFromTransportOrApp() { var connectionManager = CreateConnectionManager(); - var connection = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); connection.ApplicationTask = tcs.Task; @@ -159,7 +160,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task DisposingConnectionMultipleGetsCancellation() { var connectionManager = CreateConnectionManager(); - var connection = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); connection.ApplicationTask = tcs.Task; @@ -180,7 +181,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task DisposeInactiveConnection() { var connectionManager = CreateConnectionManager(); - var connection = connectionManager.CreateConnection();; + var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default); Assert.NotNull(connection.ConnectionId); Assert.NotNull(connection.Transport); @@ -209,7 +210,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests appLifetime.Start(); - var connection = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(PipeOptions.Default, PipeOptions.Default); connection.Application.Output.OnReaderCompleted((error, state) => { diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index b0c6245375..b129825ba4 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -77,7 +77,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests await dispatcher.ExecuteNegotiateAsync(context, httpSocketOptions); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); var connectionId = negotiateResponse.Value("connectionId"); + context.Request.QueryString = context.Request.QueryString.Add("id", connectionId); Assert.True(manager.TryGetConnection(connectionId, out var connection)); + // Fake actual connection after negotiate to populate the pipes on the connection + await dispatcher.ExecuteAsync(context, httpSocketOptions, c => Task.CompletedTask); // This write should complete immediately but it exceeds the writer threshold var writeTask = connection.Application.Output.WriteAsync(new byte[] { (byte)'b', (byte)'y', (byte)'t', (byte)'e', (byte)'s' }); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs index c62d55a357..ddf8727bc8 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs @@ -21,6 +21,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests { var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); + var context = new DefaultHttpContext(); var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs index 0ecb997412..dd251b1231 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs @@ -4,7 +4,6 @@ using System.IO; using System.IO.Pipelines; using System.Text; -using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; @@ -58,7 +57,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); var context = new DefaultHttpContext(); - var ms = new MemoryStream(); context.Response.Body = ms; var sse = new ServerSentEventsTransport(connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs index 63918075e3..cec8ea1e05 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -320,6 +320,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests // We want to verify behavior without timeout affecting it CloseTimeout = TimeSpan.FromSeconds(20) }; + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); var ws = new WebSocketsTransport(options, connection.Application, connectionContext, loggerFactory); From db0dc0f96019d6658b6758cfd748281fd809c317 Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Fri, 16 Mar 2018 16:48:05 -0700 Subject: [PATCH 3/3] Ignore writeasync failures when sending to multiple connections (#1589) --- .../BroadcastBenchmark.cs | 3 +- .../DefaultHubDispatcherBenchmark.cs | 2 +- .../DefaultHubLifetimeManager.cs | 42 ++++++++++-- .../Internal/RedisLoggerExtensions.cs | 22 +++---- .../RedisHubLifetimeManager.cs | 66 +++++++------------ .../RedisHubLifetimeManagerTests.cs | 5 +- .../DefaultHubLifetimeManagerTests.cs | 65 +++++++++++++----- 7 files changed, 123 insertions(+), 82 deletions(-) diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs index 042b5034ed..8133b10e55 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs @@ -7,6 +7,7 @@ using BenchmarkDotNet.Attributes; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; namespace Microsoft.AspNetCore.SignalR.Microbenchmarks @@ -25,7 +26,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [GlobalSetup] public void GlobalSetup() { - _hubLifetimeManager = new DefaultHubLifetimeManager(); + _hubLifetimeManager = new DefaultHubLifetimeManager(NullLogger>.Instance); IHubProtocol protocol; diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs index 5799673af7..ca1ef0f128 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -36,7 +36,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks _dispatcher = new DefaultHubDispatcher( serviceScopeFactory, - new HubContext(new DefaultHubLifetimeManager()), + new HubContext(new DefaultHubLifetimeManager(NullLogger>.Instance)), new Logger>(NullLoggerFactory.Instance)); var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs index 67cf01085f..03f07a5b6d 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.SignalR { @@ -13,6 +14,12 @@ namespace Microsoft.AspNetCore.SignalR { private readonly HubConnectionList _connections = new HubConnectionList(); private readonly HubGroupList _groups = new HubGroupList(); + private readonly ILogger _logger; + + public DefaultHubLifetimeManager(ILogger> logger) + { + _logger = logger; + } public override Task AddGroupAsync(string connectionId, string groupName) { @@ -83,7 +90,7 @@ namespace Microsoft.AspNetCore.SignalR continue; } - tasks.Add(connection.WriteAsync(message)); + tasks.Add(SafeWriteAsync(connection, message)); } return Task.WhenAll(tasks); @@ -105,7 +112,7 @@ namespace Microsoft.AspNetCore.SignalR var message = CreateInvocationMessage(methodName, args); - return connection.WriteAsync(message); + return SafeWriteAsync(connection, message); } public override Task SendGroupAsync(string groupName, string methodName, object[] args) @@ -119,7 +126,7 @@ namespace Microsoft.AspNetCore.SignalR if (group != null) { var message = CreateInvocationMessage(methodName, args); - var tasks = group.Values.Select(c => c.WriteAsync(message)); + var tasks = group.Values.Select(c => SafeWriteAsync(c, message)); return Task.WhenAll(tasks); } @@ -142,7 +149,7 @@ namespace Microsoft.AspNetCore.SignalR var group = _groups[groupName]; if (group != null) { - tasks.Add(Task.WhenAll(group.Values.Select(c => c.WriteAsync(message)))); + tasks.Add(Task.WhenAll(group.Values.Select(c => SafeWriteAsync(c, message)))); } } @@ -161,7 +168,7 @@ namespace Microsoft.AspNetCore.SignalR { var message = CreateInvocationMessage(methodName, args); var tasks = group.Values.Where(connection => !excludedIds.Contains(connection.ConnectionId)) - .Select(c => c.WriteAsync(message)); + .Select(c => SafeWriteAsync(c, message)); return Task.WhenAll(tasks); } @@ -215,5 +222,30 @@ namespace Microsoft.AspNetCore.SignalR return userIds.Contains(connection.UserIdentifier); }); } + + // This method is to protect against connections throwing synchronously when writing to them and preventing other connections from being written to + private async Task SafeWriteAsync(HubConnectionContext connection, InvocationMessage message) + { + try + { + await connection.WriteAsync(message); + } + // This exception isn't interesting to users + catch (Exception ex) + { + Log.FailedWritingMessage(_logger, ex); + } + } + + private static class Log + { + private static readonly Action _failedWritingMessage = + LoggerMessage.Define(LogLevel.Warning, new EventId(1, "FailedWritingMessage"), "Failed writing message."); + + public static void FailedWritingMessage(ILogger logger, Exception exception) + { + _failedWritingMessage(logger, exception); + } + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs index 19e9791072..e8862cf0e1 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs @@ -12,37 +12,37 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal { // Category: RedisHubLifetimeManager private static readonly Action _connectingToEndpoints = - LoggerMessage.Define(LogLevel.Information, new EventId(1, nameof(ConnectingToEndpoints)), "Connecting to Redis endpoints: {Endpoints}."); + LoggerMessage.Define(LogLevel.Information, new EventId(1, "ConnectingToEndpoints"), "Connecting to Redis endpoints: {Endpoints}."); private static readonly Action _connected = - LoggerMessage.Define(LogLevel.Information, new EventId(2, nameof(Connected)), "Connected to Redis."); + LoggerMessage.Define(LogLevel.Information, new EventId(2, "Connected"), "Connected to Redis."); private static readonly Action _subscribing = - LoggerMessage.Define(LogLevel.Trace, new EventId(3, nameof(Subscribing)), "Subscribing to channel: {Channel}."); + LoggerMessage.Define(LogLevel.Trace, new EventId(3, "Subscribing"), "Subscribing to channel: {Channel}."); private static readonly Action _receivedFromChannel = - LoggerMessage.Define(LogLevel.Trace, new EventId(4, nameof(ReceivedFromChannel)), "Received message from Redis channel {Channel}."); + LoggerMessage.Define(LogLevel.Trace, new EventId(4, "ReceivedFromChannel"), "Received message from Redis channel {Channel}."); private static readonly Action _publishToChannel = - LoggerMessage.Define(LogLevel.Trace, new EventId(5, nameof(PublishToChannel)), "Publishing message to Redis channel {Channel}."); + LoggerMessage.Define(LogLevel.Trace, new EventId(5, "PublishToChannel"), "Publishing message to Redis channel {Channel}."); private static readonly Action _unsubscribe = - LoggerMessage.Define(LogLevel.Trace, new EventId(6, nameof(Unsubscribe)), "Unsubscribing from channel: {Channel}."); + LoggerMessage.Define(LogLevel.Trace, new EventId(6, "Unsubscribe"), "Unsubscribing from channel: {Channel}."); private static readonly Action _notConnected = - LoggerMessage.Define(LogLevel.Warning, new EventId(7, nameof(Connected)), "Not connected to Redis."); + LoggerMessage.Define(LogLevel.Warning, new EventId(7, "Connected"), "Not connected to Redis."); private static readonly Action _connectionRestored = - LoggerMessage.Define(LogLevel.Information, new EventId(8, nameof(ConnectionRestored)), "Connection to Redis restored."); + LoggerMessage.Define(LogLevel.Information, new EventId(8, "ConnectionRestored"), "Connection to Redis restored."); private static readonly Action _connectionFailed = - LoggerMessage.Define(LogLevel.Warning, new EventId(9, nameof(ConnectionFailed)), "Connection to Redis failed."); + LoggerMessage.Define(LogLevel.Warning, new EventId(9, "ConnectionFailed"), "Connection to Redis failed."); private static readonly Action _failedWritingMessage = - LoggerMessage.Define(LogLevel.Warning, new EventId(10, nameof(FailedWritingMessage)), "Failed writing message."); + LoggerMessage.Define(LogLevel.Warning, new EventId(10, "FailedWritingMessage"), "Failed writing message."); private static readonly Action _internalMessageFailed = - LoggerMessage.Define(LogLevel.Warning, new EventId(11, nameof(InternalMessageFailed)), "Error processing message for internal server message."); + LoggerMessage.Define(LogLevel.Warning, new EventId(11, "InternalMessageFailed"), "Error processing message for internal server message."); public static void ConnectingToEndpoints(this ILogger logger, EndPointCollection endpoints) { diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 13d0d4d819..125924eeb0 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -173,7 +173,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis var connection = _connections[connectionId]; if (connection != null) { - return connection.WriteAsync(message.CreateInvocation()); + return SafeWriteAsync(connection, message.CreateInvocation()); } return PublishAsync(_channelNamePrefix + "." + connectionId, message); @@ -402,14 +402,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis var invocation = message.CreateInvocation(); foreach (var connection in _connections) { - try - { - tasks.Add(connection.WriteAsync(invocation)); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } + tasks.Add(SafeWriteAsync(connection, invocation)); } await Task.WhenAll(tasks); @@ -441,14 +434,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { if (!excludedIds.Contains(connection.ConnectionId)) { - try - { - tasks.Add(connection.WriteAsync(invocation)); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } + tasks.Add(SafeWriteAsync(connection, invocation)); } } @@ -524,16 +510,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis _logger.Subscribing(connectionChannel); return _bus.SubscribeAsync(connectionChannel, async (c, data) => { - try - { - var message = DeserializeMessage(data); + var message = DeserializeMessage(data); - await connection.WriteAsync(message.CreateInvocation()); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } + await SafeWriteAsync(connection, message.CreateInvocation()); }); } @@ -545,16 +524,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis // TODO: Look at optimizing (looping over connections checking for Name) return _bus.SubscribeAsync(userChannel, async (c, data) => { - try - { - var message = DeserializeMessage(data); + var message = DeserializeMessage(data); - await connection.WriteAsync(message.CreateInvocation()); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } + await SafeWriteAsync(connection, message.CreateInvocation()); }); } @@ -576,14 +548,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis continue; } - try - { - tasks.Add(groupConnection.WriteAsync(invocation)); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } + tasks.Add(SafeWriteAsync(groupConnection, invocation)); } await Task.WhenAll(tasks); @@ -611,7 +576,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis // This also saves serializing and deserializing the message! if (connection != null) { - publishTasks.Add(connection.WriteAsync(message.CreateInvocation())); + publishTasks.Add(SafeWriteAsync(connection, message.CreateInvocation())); } else { @@ -662,6 +627,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis return Task.CompletedTask; } + // This method is to protect against connections throwing synchronously when writing to them and preventing other connections from being written to + private async Task SafeWriteAsync(HubConnectionContext connection, InvocationMessage message) + { + try + { + await connection.WriteAsync(message); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } + } + private class LoggerTextWriter : TextWriter { private readonly ILogger _logger; diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs index 7f42a38095..7a6aeca22b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs @@ -503,7 +503,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests } [Fact] - public async Task WritingToLocalConnectionThatFailsThrowsException() + public async Task WritingToLocalConnectionThatFailsDoesNotThrowException() { var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() { @@ -519,8 +519,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager.OnConnectedAsync(connection).OrTimeout(); - var exception = await Assert.ThrowsAsync(() => manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout()); - Assert.Equal("Message", exception.Message); + await manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs index f298fc24eb..dd87b3527f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs @@ -3,6 +3,8 @@ using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Moq; using Xunit; @@ -11,12 +13,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests public class DefaultHubLifetimeManagerTests { [Fact] - public async Task InvokeAllAsyncWritesToAllConnectionsOutput() + public async Task SendAllAsyncWritesToAllConnectionsOutput() { using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); var connection1 = HubConnectionContextUtils.Create(client1.Connection); var connection2 = HubConnectionContextUtils.Create(client2.Connection); @@ -38,12 +40,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests } [Fact] - public async Task InvokeAllAsyncDoesNotWriteToDisconnectedConnectionsOutput() + public async Task SendAllAsyncDoesNotWriteToDisconnectedConnectionsOutput() { using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); var connection1 = HubConnectionContextUtils.Create(client1.Connection); var connection2 = HubConnectionContextUtils.Create(client2.Connection); @@ -64,12 +66,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests } [Fact] - public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput() + public async Task SendGroupAsyncWritesToAllConnectionsInGroupOutput() { using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); var connection1 = HubConnectionContextUtils.Create(client1.Connection); var connection2 = HubConnectionContextUtils.Create(client2.Connection); @@ -90,11 +92,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests } [Fact] - public async Task InvokeConnectionAsyncWritesToConnectionOutput() + public async Task SendConnectionAsyncWritesToConnectionOutput() { using (var client = new TestClient()) { - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); var connection = HubConnectionContextUtils.Create(client.Connection); await manager.OnConnectedAsync(connection).OrTimeout(); @@ -109,42 +111,71 @@ namespace Microsoft.AspNetCore.SignalR.Tests } [Fact] - public async Task InvokeConnectionAsyncThrowsIfConnectionFailsToWrite() + public async Task SendConnectionAsyncDoesNotThrowIfConnectionFailsToWrite() { using (var client = new TestClient()) { - // Force an exception when writing to connection - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection); + // Force an exception when writing to connection connectionMock.Setup(m => m.WriteAsync(It.IsAny())).Throws(new Exception("Message")); var connection = connectionMock.Object; await manager.OnConnectedAsync(connection).OrTimeout(); - var exception = await Assert.ThrowsAsync(() => manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout()); - Assert.Equal("Message", exception.Message); + await manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); } } [Fact] - public async Task InvokeConnectionAsyncOnNonExistentConnectionNoops() + public async Task SendAllAsyncSendsToAllConnectionsEvenWhenSomeFailToSend() { - var manager = new DefaultHubLifetimeManager(); + using (var client = new TestClient()) + using (var client2 = new TestClient()) + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); + + var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection); + var connectionMock2 = HubConnectionContextUtils.CreateMock(client2.Connection); + + var tcs = new TaskCompletionSource(); + var tcs2 = new TaskCompletionSource(); + // Force an exception when writing to connection + connectionMock.Setup(m => m.WriteAsync(It.IsAny())).Callback(() => tcs.TrySetResult(null)).Throws(new Exception("Message")); + connectionMock2.Setup(m => m.WriteAsync(It.IsAny())).Callback(() => tcs2.TrySetResult(null)).Throws(new Exception("Message")); + var connection = connectionMock.Object; + var connection2 = connectionMock2.Object; + + await manager.OnConnectedAsync(connection).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + + await manager.SendAllAsync("Hello", new object[] { "World" }).OrTimeout(); + + // Check that all connections were "written" to + await tcs.Task.OrTimeout(); + await tcs2.Task.OrTimeout(); + } + } + + [Fact] + public async Task SendConnectionAsyncOnNonExistentConnectionNoops() + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); await manager.SendConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" }).OrTimeout(); } [Fact] public async Task AddGroupOnNonExistentConnectionNoops() { - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); await manager.AddGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout(); } [Fact] public async Task RemoveGroupOnNonExistentConnectionNoops() { - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); await manager.RemoveGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout(); }