diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 0ddc54dbc1..2913578a42 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -16,7 +16,6 @@ using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.SignalR.Core; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Encoders; -using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Features; @@ -31,6 +30,7 @@ namespace Microsoft.AspNetCore.SignalR private static readonly PassThroughEncoder PassThroughEncoder = new PassThroughEncoder(); private readonly ConnectionContext _connectionContext; + private readonly Channel _output; private readonly ILogger _logger; private readonly CancellationTokenSource _connectionAbortedTokenSource = new CancellationTokenSource(); private readonly TaskCompletionSource _abortCompletedTcs = new TaskCompletionSource(); @@ -39,9 +39,17 @@ namespace Microsoft.AspNetCore.SignalR private Task _writingTask = Task.CompletedTask; private long _lastSendTimestamp = Stopwatch.GetTimestamp(); - public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory) + public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory): + this(connectionContext, keepAliveInterval, loggerFactory, Channel.CreateUnbounded()) { - Output = Channel.CreateUnbounded(); + } + + internal HubConnectionContext(ConnectionContext connectionContext, + TimeSpan keepAliveInterval, + ILoggerFactory loggerFactory, + Channel output) + { + _output = output; _connectionContext = connectionContext; _logger = loggerFactory.CreateLogger(); ConnectionAbortedToken = _connectionAbortedTokenSource.Token; @@ -58,13 +66,11 @@ namespace Microsoft.AspNetCore.SignalR public virtual IDictionary Metadata => _connectionContext.Metadata; - public virtual HubProtocolReaderWriter ProtocolReaderWriter { get; set; } - public virtual PipeReader Input => _connectionContext.Transport.Input; public string UserIdentifier { get; private set; } - internal virtual Channel Output { get; set; } + internal virtual HubProtocolReaderWriter ProtocolReaderWriter { get; set; } internal ExceptionDispatchInfo AbortException { get; private set; } @@ -79,21 +85,28 @@ namespace Microsoft.AspNetCore.SignalR public int? LocalPort => Features.Get()?.LocalPort; - public async Task WriteAsync(HubInvocationMessage message) + public async Task WriteAsync(HubMessage message, bool throwOnFailure = false) { - while (await Output.Writer.WaitToWriteAsync()) + while (await _output.Writer.WaitToWriteAsync()) { - if (Output.Writer.TryWrite(message)) + if (_output.Writer.TryWrite(message)) { return; } } + + _logger.OutboundChannelClosed(); + + if (throwOnFailure) + { + throw new OperationCanceledException("Outbound channel was closed while trying to write hub message"); + } } public async Task DisposeAsync() { // Nothing should be writing to the HubConnectionContext - Output.Writer.TryComplete(); + _output.Writer.TryComplete(); // This should unwind once we complete the output await _writingTask; @@ -201,17 +214,18 @@ namespace Microsoft.AspNetCore.SignalR private async Task StartAsyncCore() { + Debug.Assert(ProtocolReaderWriter != null, "Expected the ProtocolReaderWriter to be set before StartAsync is called"); + if (Features.Get() == null) { - Debug.Assert(ProtocolReaderWriter != null, "Expected the ProtocolReaderWriter to be set before StartAsync is called"); - _connectionContext.Features.Get()?.OnHeartbeat(state => ((HubConnectionContext)state).KeepAliveTick(), this); + Features.Get()?.OnHeartbeat(state => ((HubConnectionContext)state).KeepAliveTick(), this); } try { - while (await Output.Reader.WaitToReadAsync()) + while (await _output.Reader.WaitToReadAsync()) { - while (Output.Reader.TryRead(out var hubMessage)) + while (_output.Reader.TryRead(out var hubMessage)) { var buffer = ProtocolReaderWriter.WriteMessage(hubMessage); @@ -242,7 +256,7 @@ namespace Microsoft.AspNetCore.SignalR // adding a Ping message when the transport is full is unnecessary since the // transport is still in the process of sending frames. - if (Output.Writer.TryWrite(PingMessage.Instance)) + if (_output.Writer.TryWrite(PingMessage.Instance)) { _logger.SentPing(); } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index 2ceacf21fa..8c87c0abed 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -275,19 +275,9 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task SendMessageAsync(HubConnectionContext connection, HubMessage hubMessage) + private Task SendMessageAsync(HubConnectionContext connection, HubMessage hubMessage) { - while (await connection.Output.Writer.WaitToWriteAsync()) - { - if (connection.Output.Writer.TryWrite(hubMessage)) - { - return; - } - } - - // Output is closed. Cancel this invocation completely - _logger.OutboundChannelClosed(); - throw new OperationCanceledException("Outbound channel was closed while trying to write hub message"); + return connection.WriteAsync(hubMessage, throwOnFailure: true); } private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs index 66fc1f6fe9..e90a65510d 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs @@ -15,11 +15,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests { public static HubConnectionContext Create(DefaultConnectionContext connection, Channel replacementOutput = null) { - var context = new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance); + HubConnectionContext context = null; if (replacementOutput != null) { - context.Output = replacementOutput; + context = new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance, replacementOutput); } + else + { + context = new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance); + } + context.ProtocolReaderWriter = new HubProtocolReaderWriter(new JsonHubProtocol(), new PassThroughEncoder()); _ = context.StartAsync();