Don't expose Channel from HubConnectionContext (#1428)

- Change HubEndPoint to call WriteAsync
- Fixed assert  for protocol reader writer
This commit is contained in:
David Fowler 2018-02-09 22:00:28 -08:00 committed by GitHub
parent 28439d1441
commit 2ed78d5762
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 29 deletions

View File

@ -16,7 +16,6 @@ using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Features;
@ -31,6 +30,7 @@ namespace Microsoft.AspNetCore.SignalR
private static readonly PassThroughEncoder PassThroughEncoder = new PassThroughEncoder();
private readonly ConnectionContext _connectionContext;
private readonly Channel<HubMessage> _output;
private readonly ILogger _logger;
private readonly CancellationTokenSource _connectionAbortedTokenSource = new CancellationTokenSource();
private readonly TaskCompletionSource<object> _abortCompletedTcs = new TaskCompletionSource<object>();
@ -39,9 +39,17 @@ namespace Microsoft.AspNetCore.SignalR
private Task _writingTask = Task.CompletedTask;
private long _lastSendTimestamp = Stopwatch.GetTimestamp();
public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory)
public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory):
this(connectionContext, keepAliveInterval, loggerFactory, Channel.CreateUnbounded<HubMessage>())
{
Output = Channel.CreateUnbounded<HubMessage>();
}
internal HubConnectionContext(ConnectionContext connectionContext,
TimeSpan keepAliveInterval,
ILoggerFactory loggerFactory,
Channel<HubMessage> output)
{
_output = output;
_connectionContext = connectionContext;
_logger = loggerFactory.CreateLogger<HubConnectionContext>();
ConnectionAbortedToken = _connectionAbortedTokenSource.Token;
@ -58,13 +66,11 @@ namespace Microsoft.AspNetCore.SignalR
public virtual IDictionary<object, object> Metadata => _connectionContext.Metadata;
public virtual HubProtocolReaderWriter ProtocolReaderWriter { get; set; }
public virtual PipeReader Input => _connectionContext.Transport.Input;
public string UserIdentifier { get; private set; }
internal virtual Channel<HubMessage> Output { get; set; }
internal virtual HubProtocolReaderWriter ProtocolReaderWriter { get; set; }
internal ExceptionDispatchInfo AbortException { get; private set; }
@ -79,21 +85,28 @@ namespace Microsoft.AspNetCore.SignalR
public int? LocalPort => Features.Get<IHttpConnectionFeature>()?.LocalPort;
public async Task WriteAsync(HubInvocationMessage message)
public async Task WriteAsync(HubMessage message, bool throwOnFailure = false)
{
while (await Output.Writer.WaitToWriteAsync())
while (await _output.Writer.WaitToWriteAsync())
{
if (Output.Writer.TryWrite(message))
if (_output.Writer.TryWrite(message))
{
return;
}
}
_logger.OutboundChannelClosed();
if (throwOnFailure)
{
throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
}
}
public async Task DisposeAsync()
{
// Nothing should be writing to the HubConnectionContext
Output.Writer.TryComplete();
_output.Writer.TryComplete();
// This should unwind once we complete the output
await _writingTask;
@ -201,17 +214,18 @@ namespace Microsoft.AspNetCore.SignalR
private async Task StartAsyncCore()
{
Debug.Assert(ProtocolReaderWriter != null, "Expected the ProtocolReaderWriter to be set before StartAsync is called");
if (Features.Get<IConnectionInherentKeepAliveFeature>() == null)
{
Debug.Assert(ProtocolReaderWriter != null, "Expected the ProtocolReaderWriter to be set before StartAsync is called");
_connectionContext.Features.Get<IConnectionHeartbeatFeature>()?.OnHeartbeat(state => ((HubConnectionContext)state).KeepAliveTick(), this);
Features.Get<IConnectionHeartbeatFeature>()?.OnHeartbeat(state => ((HubConnectionContext)state).KeepAliveTick(), this);
}
try
{
while (await Output.Reader.WaitToReadAsync())
while (await _output.Reader.WaitToReadAsync())
{
while (Output.Reader.TryRead(out var hubMessage))
while (_output.Reader.TryRead(out var hubMessage))
{
var buffer = ProtocolReaderWriter.WriteMessage(hubMessage);
@ -242,7 +256,7 @@ namespace Microsoft.AspNetCore.SignalR
// adding a Ping message when the transport is full is unnecessary since the
// transport is still in the process of sending frames.
if (Output.Writer.TryWrite(PingMessage.Instance))
if (_output.Writer.TryWrite(PingMessage.Instance))
{
_logger.SentPing();
}

View File

@ -275,19 +275,9 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task SendMessageAsync(HubConnectionContext connection, HubMessage hubMessage)
private Task SendMessageAsync(HubConnectionContext connection, HubMessage hubMessage)
{
while (await connection.Output.Writer.WaitToWriteAsync())
{
if (connection.Output.Writer.TryWrite(hubMessage))
{
return;
}
}
// Output is closed. Cancel this invocation completely
_logger.OutboundChannelClosed();
throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
return connection.WriteAsync(hubMessage, throwOnFailure: true);
}
private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection,

View File

@ -15,11 +15,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
public static HubConnectionContext Create(DefaultConnectionContext connection, Channel<HubMessage> replacementOutput = null)
{
var context = new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance);
HubConnectionContext context = null;
if (replacementOutput != null)
{
context.Output = replacementOutput;
context = new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance, replacementOutput);
}
else
{
context = new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance);
}
context.ProtocolReaderWriter = new HubProtocolReaderWriter(new JsonHubProtocol(), new PassThroughEncoder());
_ = context.StartAsync();