diff --git a/client-ts/signalr-protocol-msgpack/package-lock.json b/client-ts/signalr-protocol-msgpack/package-lock.json index 4167e7a3cf..532a81cc67 100644 --- a/client-ts/signalr-protocol-msgpack/package-lock.json +++ b/client-ts/signalr-protocol-msgpack/package-lock.json @@ -1,6 +1,6 @@ { "name": "@aspnet/signalr-protocol-msgpack", - "version": "1.0.0-preview1-t000", + "version": "1.0.0-preview2-t000", "lockfileVersion": 1, "requires": true, "dependencies": { diff --git a/client-ts/signalr/package-lock.json b/client-ts/signalr/package-lock.json index 49d499908f..264748f026 100644 --- a/client-ts/signalr/package-lock.json +++ b/client-ts/signalr/package-lock.json @@ -1,6 +1,6 @@ { "name": "@aspnet/signalr", - "version": "1.0.0-preview1-t000", + "version": "1.0.0-preview2-t000", "lockfileVersion": 1, "requires": true, "dependencies": { diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/HubProtocolReaderWriter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/HubProtocolReaderWriter.cs index 9bb8ef6cd5..f633cab095 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/HubProtocolReaderWriter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/HubProtocolReaderWriter.cs @@ -50,5 +50,25 @@ namespace Microsoft.AspNetCore.SignalR.Internal return _dataEncoder.Encode(ms.ToArray()); } } + + public override bool Equals(object obj) + { + var readerWriter = obj as HubProtocolReaderWriter; + if (readerWriter == null) + { + return false; + } + + // Note: ReferenceEquals on HubProtocol works for our implementation of IHubProtocolResolver because we use Singletons from DI + // However if someone replaces the implementation and returns a new ProtocolResolver for every connection they wont get the perf benefits + // Memory growth is mitigated by capping the cache size + return ReferenceEquals(_dataEncoder, readerWriter._dataEncoder) && ReferenceEquals(_hubProtocol, readerWriter._hubProtocol); + } + + // This should never be used, needed because you can't override Equals without it + public override int GetHashCode() + { + return base.GetHashCode(); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs index 7ea9067b92..196ec7d44e 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs @@ -1,6 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; + namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { public abstract class HubMessage @@ -8,5 +10,42 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol protected HubMessage() { } + + // Initialize with capacity 4 for the 2 built in protocols and 2 data encoders + private readonly List _serializedMessages = new List(4); + + public byte[] WriteMessage(HubProtocolReaderWriter protocolReaderWriter) + { + for (var i = 0; i < _serializedMessages.Count; i++) + { + if (_serializedMessages[i].ProtocolReaderWriter.Equals(protocolReaderWriter)) + { + return _serializedMessages[i].Message; + } + } + + var bytes = protocolReaderWriter.WriteMessage(this); + + // We don't want to balloon memory if someone writes a poor IHubProtocolResolver + // So we cap how many caches we store and worst case just serialize the message for every connection + if (_serializedMessages.Count < 10) + { + _serializedMessages.Add(new SerializedMessage(protocolReaderWriter, bytes)); + } + + return bytes; + } + + private readonly struct SerializedMessage + { + public readonly HubProtocolReaderWriter ProtocolReaderWriter; + public readonly byte[] Message; + + public SerializedMessage(HubProtocolReaderWriter protocolReaderWriter, byte[] message) + { + ProtocolReaderWriter = protocolReaderWriter; + Message = message; + } + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs index 2611ce7263..5c8b1638dd 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs @@ -76,7 +76,6 @@ namespace Microsoft.AspNetCore.SignalR var tasks = new List(count); var message = CreateInvocationMessage(methodName, args); - // TODO: serialize once per format by providing a different stream? foreach (var connection in _connections) { if (!include(connection)) diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 2b8b73eb41..8702373375 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -39,6 +39,7 @@ namespace Microsoft.AspNetCore.SignalR private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1); private long _lastSendTimestamp = Stopwatch.GetTimestamp(); + private byte[] _cachedPingMessage; public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory) { @@ -46,11 +47,6 @@ namespace Microsoft.AspNetCore.SignalR _logger = loggerFactory.CreateLogger(); ConnectionAbortedToken = _connectionAbortedTokenSource.Token; _keepAliveDuration = (int)keepAliveInterval.TotalMilliseconds * (Stopwatch.Frequency / 1000); - - if (Features.Get() == null) - { - Features.Get()?.OnHeartbeat(state => ((HubConnectionContext)state).KeepAliveTick(), this); - } } public virtual CancellationToken ConnectionAbortedToken { get; } @@ -84,17 +80,42 @@ namespace Microsoft.AspNetCore.SignalR public virtual async Task WriteAsync(HubMessage message) { + await _writeLock.WaitAsync(); + try { - await _writeLock.WaitAsync(); - - var buffer = ProtocolReaderWriter.WriteMessage(message); - + // This will internally cache the buffer for each unique HubProtocol/DataEncoder combination + // So that we don't serialize the HubMessage for every single connection + var buffer = message.WriteMessage(ProtocolReaderWriter); _connectionContext.Transport.Output.Write(buffer); Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp()); - await _connectionContext.Transport.Output.FlushAsync(CancellationToken.None); + await _connectionContext.Transport.Output.FlushAsync(); + } + finally + { + _writeLock.Release(); + } + } + + private async Task TryWritePingAsync() + { + // Don't wait for the lock, if it returns false that means someone wrote to the connection + // and we don't need to send a ping anymore + if (!await _writeLock.WaitAsync(0)) + { + return; + } + + try + { + Debug.Assert(_cachedPingMessage != null); + _connectionContext.Transport.Output.Write(_cachedPingMessage); + + Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp()); + + await _connectionContext.Transport.Output.FlushAsync(); } finally { @@ -154,11 +175,18 @@ namespace Microsoft.AspNetCore.SignalR : TransferMode.Text; ProtocolReaderWriter = new HubProtocolReaderWriter(protocol, dataEncoder); + _cachedPingMessage = ProtocolReaderWriter.WriteMessage(PingMessage.Instance); Log.UsingHubProtocol(_logger, protocol.Name); UserIdentifier = userIdProvider.GetUserId(this); + if (Features.Get() == null) + { + // Only register KeepAlive after protocol negotiated otherwise KeepAliveTick could try to write without having a ProtocolReaderWriter + Features.Get()?.OnHeartbeat(state => ((HubConnectionContext)state).KeepAliveTick(), this); + } + return true; } } @@ -210,11 +238,8 @@ namespace Microsoft.AspNetCore.SignalR // adding a Ping message when the transport is full is unnecessary since the // transport is still in the process of sending frames. + _ = TryWritePingAsync(); Log.SentPing(_logger); - - _ = WriteAsync(PingMessage.Instance); - - Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp()); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs index 9178b2a568..954ca28534 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Buffers; using System.Collections.Generic; using System.IO; using System.IO.Pipelines; @@ -27,7 +26,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public DefaultConnectionContext Connection { get; } public Task Connected => ((TaskCompletionSource)Connection.Metadata["ConnectedTask"]).Task; - public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, bool addClaimId = false) + public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, IDataEncoder dataEncoder = null, IInvocationBinder invocationBinder = null, bool addClaimId = false) { var options = new PipeOptions(readerScheduler: synchronousCallbacks ? PipeScheduler.Inline : null); var pair = DuplexPipe.CreateConnectionPair(options, options); @@ -44,7 +43,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests Connection.Metadata["ConnectedTask"] = new TaskCompletionSource(); protocol = protocol ?? new JsonHubProtocol(); - _protocolReaderWriter = new HubProtocolReaderWriter(protocol, new PassThroughEncoder()); + dataEncoder = dataEncoder ?? new PassThroughEncoder(); + _protocolReaderWriter = new HubProtocolReaderWriter(protocol, dataEncoder); _invocationBinder = invocationBinder ?? new DefaultInvocationBinder(); _cts = new CancellationTokenSource(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 21e6ff20c3..d50c33a0ff 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -9,6 +9,7 @@ using System.Security.Claims; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Encoders; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils; using Microsoft.AspNetCore.Sockets; @@ -1413,6 +1414,43 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task CanSendToConnectionsWithDifferentProtocols() + { + var serviceProvider = HubEndPointTestUtils.CreateServiceProvider(); + var endPoint = serviceProvider.GetService>(); + + using (var client1 = new TestClient(protocol: new JsonHubProtocol())) + using (var client2 = new TestClient(protocol: new MessagePackHubProtocol(), dataEncoder: new Base64Encoder())) + { + var endPointLifetime1 = endPoint.OnConnectedAsync(client1.Connection); + var endPointLifetime2 = endPoint.OnConnectedAsync(client2.Connection); + + await client1.Connected.OrTimeout(); + await client2.Connected.OrTimeout(); + + var sentMessage = "From Json"; + + await client1.SendInvocationAsync(nameof(MethodHub.BroadcastMethod), sentMessage); + var message1 = await client1.ReadAsync().OrTimeout(); + var message2 = await client2.ReadAsync().OrTimeout(); + + var completion1 = message1 as InvocationMessage; + Assert.NotNull(completion1); + Assert.Equal(sentMessage, completion1.Arguments[0]); + var completion2 = message2 as InvocationMessage; + Assert.NotNull(completion2); + // Argument[0] is a 'MsgPackObject' with a string internally, ToString to compare it + Assert.Equal(sentMessage, completion2.Arguments[0].ToString()); + + client1.Dispose(); + client2.Dispose(); + + await endPointLifetime1.OrTimeout(); + await endPointLifetime2.OrTimeout(); + } + } + public static IEnumerable StreamingMethodAndHubProtocols { get