Client pipelines (#1435)

- Reworked the Client to be based on pipelines instead of Channels
- SendAsync no longer fails if the http request itself fails but the connection is closed as a result.
- Updated tests
- Base64Encoder needed to support multiple messages in the same span of data
This commit is contained in:
David Fowler 2018-02-12 22:27:43 -08:00 committed by GitHub
parent 0159b53e5e
commit 6c22f25818
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 483 additions and 500 deletions

View File

@ -430,7 +430,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
return;
}
//TODO: Optimize this!
// TODO: Optimize this!
// Copying the callbacks to avoid concurrency issues
InvocationHandler[] copiedHandlers;
lock (handlers)

View File

@ -10,16 +10,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Encoders
{
public class Base64Encoder : IDataEncoder
{
public ReadOnlySpan<byte> Decode(byte[] payload)
public bool TryDecode(ref ReadOnlySpan<byte> buffer, out ReadOnlySpan<byte> data)
{
ReadOnlySpan<byte> buffer = payload;
LengthPrefixedTextMessageParser.TryParseMessage(ref buffer, out var message);
Span<byte> decoded = new byte[Base64.GetMaxDecodedFromUtf8Length(message.Length)];
var status = Base64.DecodeFromUtf8(message, decoded, out _, out var written);
Debug.Assert(status == OperationStatus.Done);
return decoded.Slice(0, written);
if (LengthPrefixedTextMessageParser.TryParseMessage(ref buffer, out var message))
{
Span<byte> decoded = new byte[Base64.GetMaxDecodedFromUtf8Length(message.Length)];
var status = Base64.DecodeFromUtf8(message, decoded, out _, out var written);
Debug.Assert(status == OperationStatus.Done);
data = decoded.Slice(0, written);
return true;
}
return false;
}
private const int Int32OverflowLength = 10;

View File

@ -8,6 +8,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Encoders
public interface IDataEncoder
{
byte[] Encode(byte[] payload);
ReadOnlySpan<byte> Decode(byte[] payload);
bool TryDecode(ref ReadOnlySpan<byte> buffer, out ReadOnlySpan<byte> data);
}
}

View File

@ -7,9 +7,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Encoders
{
public class PassThroughEncoder : IDataEncoder
{
public ReadOnlySpan<byte> Decode(byte[] payload)
public bool TryDecode(ref ReadOnlySpan<byte> buffer, out ReadOnlySpan<byte> data)
{
return payload;
data = buffer;
buffer = Array.Empty<byte>();
return true;
}
public byte[] Encode(byte[] payload)

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections;
using System.Collections.Generic;
@ -32,8 +33,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public bool ReadMessages(byte[] input, IInvocationBinder binder, out IList<HubMessage> messages)
{
var buffer = _dataEncoder.Decode(input);
return _hubProtocol.TryParseMessages(buffer, binder, out messages);
messages = new List<HubMessage>();
ReadOnlySpan<byte> span = input;
while (span.Length > 0 && _dataEncoder.TryDecode(ref span, out var data))
{
_hubProtocol.TryParseMessages(data, binder, messages);
}
return messages.Count > 0;
}
public byte[] WriteMessage(HubMessage hubMessage)

View File

@ -13,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
ProtocolType Type { get; }
bool TryParseMessages(ReadOnlySpan<byte> input, IInvocationBinder binder, out IList<HubMessage> messages);
bool TryParseMessages(ReadOnlySpan<byte> input, IInvocationBinder binder, IList<HubMessage> messages);
void WriteMessage(HubMessage message, Stream output);
}

View File

@ -43,10 +43,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
public ProtocolType Type => ProtocolType.Text;
public bool TryParseMessages(ReadOnlySpan<byte> input, IInvocationBinder binder, out IList<HubMessage> messages)
public bool TryParseMessages(ReadOnlySpan<byte> input, IInvocationBinder binder, IList<HubMessage> messages)
{
messages = new List<HubMessage>();
while (TextMessageParser.TryParseMessage(ref input, out var payload))
{
// TODO: Need a span-native JSON parser!

View File

@ -35,10 +35,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
SerializationContext = options.Value.SerializationContext;
}
public bool TryParseMessages(ReadOnlySpan<byte> input, IInvocationBinder binder, out IList<HubMessage> messages)
public bool TryParseMessages(ReadOnlySpan<byte> input, IInvocationBinder binder, IList<HubMessage> messages)
{
messages = new List<HubMessage>();
while (BinaryMessageParser.TryParseMessage(ref input, out var payload))
{
using (var memoryStream = new MemoryStream(payload.ToArray()))

View File

@ -3,12 +3,11 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Channels;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Sockets.Client.Http;
using Microsoft.AspNetCore.Sockets.Client.Internal;
@ -31,7 +30,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
private volatile ConnectionState _connectionState = ConnectionState.Disconnected;
private readonly object _stateChangeLock = new object();
private volatile ChannelConnection<byte[], SendMessage> _transportChannel;
private volatile IDuplexPipe _transportChannel;
private readonly HttpClient _httpClient;
private readonly HttpOptions _httpOptions;
private volatile ITransport _transport;
@ -43,8 +42,8 @@ namespace Microsoft.AspNetCore.Sockets.Client
private string _connectionId;
private Exception _abortException;
private readonly TimeSpan _eventQueueDrainTimeout = TimeSpan.FromSeconds(5);
private ChannelReader<byte[]> Input => _transportChannel.Input;
private ChannelWriter<SendMessage> Output => _transportChannel.Output;
private PipeReader Input => _transportChannel.Input;
private PipeWriter Output => _transportChannel.Output;
private readonly List<ReceiveCallback> _callbacks = new List<ReceiveCallback>();
private readonly TransportType _requestedTransportType = TransportType.All;
private readonly ConnectionLogScope _logScope;
@ -187,7 +186,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
_closeTcs = new TaskCompletionSource<object>();
_ = Input.Completion.ContinueWith(async t =>
Input.OnWriterCompleted(async (exception, state) =>
{
// Grab the exception and then clear it.
// See comment at AbortAsync for more discussion on the thread-safety
@ -221,9 +220,9 @@ namespace Microsoft.AspNetCore.Sockets.Client
try
{
if (t.IsFaulted)
if (exception != null)
{
Closed?.Invoke(t.Exception.InnerException);
Closed?.Invoke(exception);
}
else
{
@ -237,7 +236,8 @@ namespace Microsoft.AspNetCore.Sockets.Client
// Suppress (but log) the exception, this is user code
_logger.ErrorDuringClosedEvent(ex);
}
});
}, null);
_receiveLoopTask = ReceiveAsync();
}
@ -325,15 +325,14 @@ namespace Microsoft.AspNetCore.Sockets.Client
private async Task StartTransport(Uri connectUrl)
{
var applicationToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationSide = ChannelConnection.Create(applicationToTransport, transportToApplication);
_transportChannel = ChannelConnection.Create(transportToApplication, applicationToTransport);
var options = new PipeOptions(readerScheduler: PipeScheduler.ThreadPool);
var pair = DuplexPipe.CreateConnectionPair(options, options);
_transportChannel = pair.Transport;
// Start the transport, giving it one end of the pipeline
try
{
await _transport.StartAsync(connectUrl, applicationSide, GetTransferMode(), this);
await _transport.StartAsync(connectUrl, pair.Application, GetTransferMode(), this);
// actual transfer mode can differ from the one that was requested so set it on the feature
if (!_transport.Mode.HasValue)
@ -379,57 +378,72 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
_logger.HttpReceiveStarted();
while (await Input.WaitToReadAsync())
while (true)
{
if (_connectionState != ConnectionState.Connected)
{
_logger.SkipRaisingReceiveEvent();
// drain
Input.TryRead(out _);
continue;
break;
}
if (Input.TryRead(out var buffer))
var result = await Input.ReadAsync();
var buffer = result.Buffer;
try
{
_logger.ScheduleReceiveEvent();
_ = _eventQueue.Enqueue(async () =>
if (!buffer.IsEmpty)
{
_logger.RaiseReceiveEvent();
_logger.ScheduleReceiveEvent();
var data = buffer.ToArray();
// Copying the callbacks to avoid concurrency issues
ReceiveCallback[] callbackCopies;
lock (_callbacks)
_ = _eventQueue.Enqueue(async () =>
{
callbackCopies = new ReceiveCallback[_callbacks.Count];
_callbacks.CopyTo(callbackCopies);
}
_logger.RaiseReceiveEvent();
foreach (var callbackObject in callbackCopies)
{
try
// Copying the callbacks to avoid concurrency issues
ReceiveCallback[] callbackCopies;
lock (_callbacks)
{
await callbackObject.InvokeAsync(buffer);
callbackCopies = new ReceiveCallback[_callbacks.Count];
_callbacks.CopyTo(callbackCopies);
}
catch (Exception ex)
foreach (var callbackObject in callbackCopies)
{
_logger.ExceptionThrownFromCallback(nameof(OnReceived), ex);
try
{
await callbackObject.InvokeAsync(data);
}
catch (Exception ex)
{
_logger.ExceptionThrownFromCallback(nameof(OnReceived), ex);
}
}
}
});
});
}
else if (result.IsCompleted)
{
break;
}
}
else
finally
{
_logger.FailedReadingMessage();
Input.AdvanceTo(buffer.End);
}
}
await Input.Completion;
}
catch (Exception ex)
{
Output.TryComplete(ex);
Input.Complete(ex);
_logger.ErrorReceiving(ex);
}
finally
{
Input.Complete();
}
_logger.EndReceive();
}
@ -450,23 +464,11 @@ namespace Microsoft.AspNetCore.Sockets.Client
"Cannot send messages when the connection is not in the Connected state.");
}
// TaskCreationOptions.RunContinuationsAsynchronously ensures that continuations awaiting
// SendAsync (i.e. user's code) are not running on the same thread as the code that sets
// TaskCompletionSource result. This way we prevent from user's code blocking our channel
// send loop.
var sendTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
var message = new SendMessage(data, sendTcs);
_logger.SendingMessage();
while (await Output.WaitToWriteAsync(cancellationToken))
{
if (Output.TryWrite(message))
{
await sendTcs.Task;
break;
}
}
cancellationToken.ThrowIfCancellationRequested();
await Output.WriteAsync(data);
}
// AbortAsync creates a few thread-safety races that we are OK with.
@ -539,7 +541,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
if (_transportChannel != null)
{
Output.TryComplete();
Output.Complete();
}
if (transport != null)

View File

@ -3,13 +3,13 @@
using System;
using System.Threading.Tasks;
using System.Threading.Channels;
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets.Client
{
public interface ITransport
{
Task StartAsync(Uri url, Channel<byte[], SendMessage> application, TransferMode requestedTransferMode, IConnection connection);
Task StartAsync(Uri url, IDuplexPipe application, TransferMode requestedTransferMode, IConnection connection);
Task StopAsync();
TransferMode? Mode { get; }
}

View File

@ -1,8 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Net.Http;
namespace Microsoft.AspNetCore.Sockets.Client
{
public interface ITransportFactory

View File

@ -47,8 +47,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal
private static readonly Action<ILogger, int, Exception> _messageToApp =
LoggerMessage.Define<int>(LogLevel.Debug, new EventId(12, nameof(MessageToApp)), "Passing message to application. Payload size: {count}.");
private static readonly Action<ILogger, int, Exception> _receivedFromApp =
LoggerMessage.Define<int>(LogLevel.Debug, new EventId(13, nameof(ReceivedFromApp)), "Received message from application. Payload size: {count}.");
private static readonly Action<ILogger, long, Exception> _receivedFromApp =
LoggerMessage.Define<long>(LogLevel.Debug, new EventId(13, nameof(ReceivedFromApp)), "Received message from application. Payload size: {count}.");
private static readonly Action<ILogger, Exception> _sendMessageCanceled =
LoggerMessage.Define(LogLevel.Information, new EventId(14, nameof(SendMessageCanceled)), "Sending a message canceled.");
@ -66,8 +66,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal
LoggerMessage.Define(LogLevel.Debug, new EventId(18, nameof(CancelMessage)), "Canceled passing message to application.");
// Category: ServerSentEventsTransport and LongPollingTransport
private static readonly Action<ILogger, int, Uri, Exception> _sendingMessages =
LoggerMessage.Define<int, Uri>(LogLevel.Debug, new EventId(10, nameof(SendingMessages)), "Sending {count} message(s) to the server using url: {url}.");
private static readonly Action<ILogger, long, Uri, Exception> _sendingMessages =
LoggerMessage.Define<long, Uri>(LogLevel.Debug, new EventId(10, nameof(SendingMessages)), "Sending {count} bytes to the server using url: {url}.");
private static readonly Action<ILogger, Exception> _sentSuccessfully =
LoggerMessage.Define(LogLevel.Debug, new EventId(11, nameof(SentSuccessfully)), "Message(s) sent successfully.");
@ -221,7 +221,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal
_sendStarted(logger, null);
}
public static void ReceivedFromApp(this ILogger logger, int count)
public static void ReceivedFromApp(this ILogger logger, long count)
{
_receivedFromApp(logger, count, null);
}
@ -261,7 +261,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal
_cancelMessage(logger, null);
}
public static void SendingMessages(this ILogger logger, int count, Uri url)
public static void SendingMessages(this ILogger logger, long count, Uri url)
{
_sendingMessages(logger, count, url, null);
}

View File

@ -2,10 +2,10 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets.Client.Http;
using Microsoft.AspNetCore.Sockets.Client.Internal;
@ -20,7 +20,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
private readonly HttpClient _httpClient;
private readonly HttpOptions _httpOptions;
private readonly ILogger _logger;
private Channel<byte[], SendMessage> _application;
private IDuplexPipe _application;
private Task _sender;
private Task _poller;
@ -41,7 +41,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<LongPollingTransport>();
}
public Task StartAsync(Uri url, Channel<byte[], SendMessage> application, TransferMode requestedTransferMode, IConnection connection)
public Task StartAsync(Uri url, IDuplexPipe application, TransferMode requestedTransferMode, IConnection connection)
{
if (requestedTransferMode != TransferMode.Binary && requestedTransferMode != TransferMode.Text)
{
@ -62,7 +62,8 @@ namespace Microsoft.AspNetCore.Sockets.Client
Running = Task.WhenAll(_sender, _poller).ContinueWith(t =>
{
_logger.TransportStopped(t.Exception?.InnerException);
_application.Writer.TryComplete(t.IsFaulted ? t.Exception.InnerException : null);
_application.Output.Complete(t.Exception?.InnerException);
_application.Input.Complete();
return t;
}).Unwrap();
@ -122,17 +123,11 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
_logger.ReceivedMessages();
// Until Pipeline starts natively supporting BytesReader, this is the easiest way to do this.
// TODO: Use CopyToAsync here
var payload = await response.Content.ReadAsByteArrayAsync();
if (payload.Length > 0)
{
while (!_application.Writer.TryWrite(payload))
{
if (cancellationToken.IsCancellationRequested || !await _application.Writer.WaitToWriteAsync(cancellationToken))
{
return;
}
}
await _application.Output.WriteAsync(payload);
}
}
}

View File

@ -2,12 +2,10 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Channels;
using Microsoft.AspNetCore.Sockets.Client.Http;
using Microsoft.AspNetCore.Sockets.Client.Internal;
using Microsoft.Extensions.Logging;
@ -16,87 +14,61 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
internal static class SendUtils
{
public static async Task SendMessages(Uri sendUrl, Channel<byte[], SendMessage> application, HttpClient httpClient,
public static async Task SendMessages(Uri sendUrl, IDuplexPipe application, HttpClient httpClient,
HttpOptions httpOptions, CancellationTokenSource transportCts, ILogger logger)
{
logger.SendStarted();
IList<SendMessage> messages = null;
try
{
while (await application.Reader.WaitToReadAsync(transportCts.Token))
while (true)
{
// Grab as many messages as we can from the channel
messages = new List<SendMessage>();
while (!transportCts.IsCancellationRequested && application.Reader.TryRead(out SendMessage message))
var result = await application.Input.ReadAsync(transportCts.Token);
var buffer = result.Buffer;
try
{
messages.Add(message);
}
// Grab as many messages as we can from the channel
transportCts.Token.ThrowIfCancellationRequested();
if (messages.Count > 0)
{
logger.SendingMessages(messages.Count, sendUrl);
// Send them in a single post
var request = new HttpRequestMessage(HttpMethod.Post, sendUrl);
PrepareHttpRequest(request, httpOptions);
// TODO: We can probably use a pipeline here or some kind of pooled memory.
// But where do we get the pool from? ArrayBufferPool.Instance?
var memoryStream = new MemoryStream();
foreach (var message in messages)
transportCts.Token.ThrowIfCancellationRequested();
if (!buffer.IsEmpty)
{
if (message.Payload != null)
{
memoryStream.Write(message.Payload, 0, message.Payload.Length);
}
logger.SendingMessages(buffer.Length, sendUrl);
// Send them in a single post
var request = new HttpRequestMessage(HttpMethod.Post, sendUrl);
PrepareHttpRequest(request, httpOptions);
// TODO: Use a custom stream implementation over the ReadOnlyBuffer<byte>
request.Content = new ByteArrayContent(buffer.ToArray());
var response = await httpClient.SendAsync(request, transportCts.Token);
response.EnsureSuccessStatusCode();
logger.SentSuccessfully();
}
memoryStream.Position = 0;
// Set the, now filled, stream as the content
request.Content = new StreamContent(memoryStream);
var response = await httpClient.SendAsync(request, transportCts.Token);
response.EnsureSuccessStatusCode();
logger.SentSuccessfully();
foreach (var message in messages)
else if (result.IsCompleted)
{
message.SendResult?.TrySetResult(null);
break;
}
else
{
logger.NoMessages();
}
}
else
finally
{
logger.NoMessages();
application.Input.AdvanceTo(buffer.End);
}
}
}
catch (OperationCanceledException)
{
// transport is being closed
if (messages != null)
{
foreach (var message in messages)
{
// This will no-op for any messages that were already marked as completed.
message.SendResult?.TrySetCanceled();
}
}
logger.SendCanceled();
}
catch (Exception ex)
{
logger.ErrorSending(sendUrl, ex);
if (messages != null)
{
foreach (var message in messages)
{
// This will no-op for any messages that were already marked as completed.
message.SendResult?.TrySetException(ex);
}
}
throw;
}
finally

View File

@ -2,13 +2,12 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.IO.Pipelines;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Channels;
using Microsoft.AspNetCore.Sockets.Client.Http;
using Microsoft.AspNetCore.Sockets.Client.Internal;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
@ -19,14 +18,13 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
public class ServerSentEventsTransport : ITransport
{
private static readonly MemoryPool _memoryPool = new MemoryPool();
private readonly HttpClient _httpClient;
private readonly HttpOptions _httpOptions;
private readonly ILogger _logger;
private readonly CancellationTokenSource _transportCts = new CancellationTokenSource();
private readonly ServerSentEventsMessageParser _parser = new ServerSentEventsMessageParser();
private Channel<byte[], SendMessage> _application;
private IDuplexPipe _application;
public Task Running { get; private set; } = Task.CompletedTask;
@ -48,7 +46,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<ServerSentEventsTransport>();
}
public Task StartAsync(Uri url, Channel<byte[], SendMessage> application, TransferMode requestedTransferMode, IConnection connection)
public Task StartAsync(Uri url, IDuplexPipe application, TransferMode requestedTransferMode, IConnection connection)
{
if (requestedTransferMode != TransferMode.Binary && requestedTransferMode != TransferMode.Text)
{
@ -66,15 +64,16 @@ namespace Microsoft.AspNetCore.Sockets.Client
Running = Task.WhenAll(sendTask, receiveTask).ContinueWith(t =>
{
_logger.TransportStopped(t.Exception?.InnerException);
_application.Output.Complete(t.Exception?.InnerException);
_application.Input.Complete();
_application.Writer.TryComplete(t.IsFaulted ? t.Exception.InnerException : null);
return t;
}).Unwrap();
return Task.CompletedTask;
}
private async Task OpenConnection(Channel<byte[], SendMessage> application, Uri url, CancellationToken cancellationToken)
private async Task OpenConnection(IDuplexPipe application, Uri url, CancellationToken cancellationToken)
{
_logger.StartReceive();
@ -83,59 +82,60 @@ namespace Microsoft.AspNetCore.Sockets.Client
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
var stream = await response.Content.ReadAsStreamAsync();
var pipelineReader = StreamPipeConnection.CreateReader(new PipeOptions(_memoryPool), stream);
var readCancellationRegistration = cancellationToken.Register(
reader => ((PipeReader)reader).CancelPendingRead(), pipelineReader);
try
using (var stream = await response.Content.ReadAsStreamAsync())
{
while (true)
var pipelineReader = StreamPipeConnection.CreateReader(PipeOptions.Default, stream);
var readCancellationRegistration = cancellationToken.Register(
reader => ((PipeReader)reader).CancelPendingRead(), pipelineReader);
try
{
var result = await pipelineReader.ReadAsync();
var input = result.Buffer;
if (result.IsCancelled || (input.IsEmpty && result.IsCompleted))
while (true)
{
_logger.EventStreamEnded();
break;
}
var consumed = input.Start;
var examined = input.End;
try
{
var parseResult = _parser.ParseMessage(input, out consumed, out examined, out var buffer);
switch (parseResult)
var result = await pipelineReader.ReadAsync();
var input = result.Buffer;
if (result.IsCancelled || (input.IsEmpty && result.IsCompleted))
{
case ServerSentEventsMessageParser.ParseResult.Completed:
_application.Writer.TryWrite(buffer);
_parser.Reset();
break;
case ServerSentEventsMessageParser.ParseResult.Incomplete:
if (result.IsCompleted)
{
throw new FormatException("Incomplete message.");
}
break;
_logger.EventStreamEnded();
break;
}
var consumed = input.Start;
var examined = input.End;
try
{
var parseResult = _parser.ParseMessage(input, out consumed, out examined, out var buffer);
switch (parseResult)
{
case ServerSentEventsMessageParser.ParseResult.Completed:
await _application.Output.WriteAsync(buffer);
_parser.Reset();
break;
case ServerSentEventsMessageParser.ParseResult.Incomplete:
if (result.IsCompleted)
{
throw new FormatException("Incomplete message.");
}
break;
}
}
finally
{
pipelineReader.AdvanceTo(consumed, examined);
}
}
finally
{
pipelineReader.AdvanceTo(consumed, examined);
}
}
}
catch (OperationCanceledException)
{
_logger.ReceiveCanceled();
}
finally
{
readCancellationRegistration.Dispose();
_transportCts.Cancel();
stream.Dispose();
_logger.ReceiveStopped();
catch (OperationCanceledException)
{
_logger.ReceiveCanceled();
}
finally
{
readCancellationRegistration.Dispose();
_transportCts.Cancel();
_logger.ReceiveStopped();
}
}
}
@ -143,7 +143,6 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
_logger.TransportStopping();
_transportCts.Cancel();
_application.Writer.TryComplete();
try
{

View File

@ -2,11 +2,9 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets.Client.Http;
using Microsoft.AspNetCore.Sockets.Client.Internal;
@ -18,7 +16,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
public class WebSocketsTransport : ITransport
{
private readonly ClientWebSocket _webSocket;
private Channel<byte[], SendMessage> _application;
private IDuplexPipe _application;
private readonly CancellationTokenSource _transportCts = new CancellationTokenSource();
private readonly CancellationTokenSource _receiveCts = new CancellationTokenSource();
private readonly ILogger _logger;
@ -54,7 +52,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<WebSocketsTransport>();
}
public async Task StartAsync(Uri url, Channel<byte[], SendMessage> application, TransferMode requestedTransferMode, IConnection connection)
public async Task StartAsync(Uri url, IDuplexPipe application, TransferMode requestedTransferMode, IConnection connection)
{
if (url == null)
{
@ -77,8 +75,8 @@ namespace Microsoft.AspNetCore.Sockets.Client
_logger.StartTransport(Mode.Value);
await Connect(url);
var sendTask = SendMessages(url);
var receiveTask = ReceiveMessages(url);
var sendTask = SendMessages();
var receiveTask = ReceiveMessages();
// TODO: Handle TCP connection errors
// https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251
@ -86,84 +84,48 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
_webSocket.Dispose();
_logger.TransportStopped(t.Exception?.InnerException);
_application.Writer.TryComplete(t.IsFaulted ? t.Exception.InnerException : null);
_application.Output.Complete(t.Exception?.InnerException);
_application.Input.Complete();
return t;
}).Unwrap();
}
private async Task ReceiveMessages(Uri pollUrl)
private async Task ReceiveMessages()
{
_logger.StartReceive();
try
{
while (!_receiveCts.Token.IsCancellationRequested)
while (true)
{
const int bufferSize = 4096;
var totalBytes = 0;
var incomingMessage = new List<ArraySegment<byte>>();
WebSocketReceiveResult receiveResult;
do
var memory = _application.Output.GetMemory();
// REVIEW: Use new Memory<byte> websocket APIs on .NET Core 2.1
memory.TryGetArray(out var arraySegment);
// Exceptions are handled above where the send and receive tasks are being run.
var receiveResult = await _webSocket.ReceiveAsync(arraySegment, _receiveCts.Token);
if (receiveResult.MessageType == WebSocketMessageType.Close)
{
var buffer = new ArraySegment<byte>(new byte[bufferSize]);
_logger.WebSocketClosed(receiveResult.CloseStatus);
//Exceptions are handled above where the send and receive tasks are being run.
receiveResult = await _webSocket.ReceiveAsync(buffer, _receiveCts.Token);
if (receiveResult.MessageType == WebSocketMessageType.Close)
if (receiveResult.CloseStatus != WebSocketCloseStatus.NormalClosure)
{
_logger.WebSocketClosed(receiveResult.CloseStatus);
_application.Writer.Complete(
receiveResult.CloseStatus == WebSocketCloseStatus.NormalClosure
? null
: new InvalidOperationException(
$"Websocket closed with error: {receiveResult.CloseStatus}."));
return;
throw new InvalidOperationException($"Websocket closed with error: {receiveResult.CloseStatus}.");
}
_logger.MessageReceived(receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage);
var truncBuffer = new ArraySegment<byte>(buffer.Array, 0, receiveResult.Count);
incomingMessage.Add(truncBuffer);
totalBytes += receiveResult.Count;
} while (!receiveResult.EndOfMessage);
//Making sure the message type is either text or binary
Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text), "Unexpected message type");
var messageBuffer = new byte[totalBytes];
if (incomingMessage.Count > 1)
{
var offset = 0;
for (var i = 0; i < incomingMessage.Count; i++)
{
Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count);
offset += incomingMessage[i].Count;
}
}
else
{
Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, messageBuffer, 0, incomingMessage[0].Count);
return;
}
try
_logger.MessageReceived(receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage);
_application.Output.Advance(receiveResult.Count);
if (receiveResult.EndOfMessage)
{
if (!_transportCts.Token.IsCancellationRequested)
{
_logger.MessageToApp(messageBuffer.Length);
while (await _application.Writer.WaitToWriteAsync(_transportCts.Token))
{
if (_application.Writer.TryWrite(messageBuffer))
{
incomingMessage.Clear();
break;
}
}
}
}
catch (OperationCanceledException)
{
_logger.CancelMessage();
await _application.Output.FlushAsync(_transportCts.Token);
}
}
}
@ -173,12 +135,13 @@ namespace Microsoft.AspNetCore.Sockets.Client
}
finally
{
// We're done writing
_logger.ReceiveStopped();
_transportCts.Cancel();
}
}
private async Task SendMessages(Uri sendUrl)
private async Task SendMessages()
{
_logger.SendStarted();
@ -189,32 +152,38 @@ namespace Microsoft.AspNetCore.Sockets.Client
try
{
while (await _application.Reader.WaitToReadAsync(_transportCts.Token))
while (true)
{
while (_application.Reader.TryRead(out SendMessage message))
var result = await _application.Input.ReadAsync(_transportCts.Token);
var buffer = result.Buffer;
try
{
try
if (!buffer.IsEmpty)
{
_logger.ReceivedFromApp(message.Payload.Length);
_logger.ReceivedFromApp(buffer.Length);
await _webSocket.SendAsync(new ArraySegment<byte>(message.Payload), webSocketMessageType, true, _transportCts.Token);
message.SendResult.SetResult(null);
await _webSocket.SendAsync(new ArraySegment<byte>(buffer.ToArray()), webSocketMessageType, true, _transportCts.Token);
}
catch (OperationCanceledException)
else if (result.IsCompleted)
{
_logger.SendMessageCanceled();
message.SendResult.SetCanceled();
await CloseWebSocket();
break;
}
catch (Exception ex)
{
_logger.ErrorSendingMessage(ex);
message.SendResult.SetException(ex);
await CloseWebSocket();
throw;
}
}
catch (OperationCanceledException)
{
_logger.SendMessageCanceled();
await CloseWebSocket();
break;
}
catch (Exception ex)
{
_logger.ErrorSendingMessage(ex);
await CloseWebSocket();
throw;
}
finally
{
_application.Input.AdvanceTo(buffer.End);
}
}
}

View File

@ -20,11 +20,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests.Common
public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
{
var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout));
var cts = new CancellationTokenSource();
var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout, cts.Token));
if (completed != task)
{
throw new TimeoutException(GetMessage(memberName, filePath, lineNumber));
}
cts.Cancel();
await task;
}
@ -36,11 +38,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests.Common
public static async Task<T> OrTimeout<T>(this Task<T> task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null)
{
var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout));
var cts = new CancellationTokenSource();
var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout, cts.Token));
if (completed != task)
{
throw new TimeoutException(GetMessage(memberName, filePath, lineNumber));
}
cts.Cancel();
return await task;
}

View File

@ -219,7 +219,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
public async Task CanInvokeClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path)
{
using (StartLog(out var loggerFactory, $"{nameof(CanInvokeClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}"))
using (StartLog(out var loggerFactory, LogLevel.Trace, $"{nameof(CanInvokeClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}"))
{
const string originalMessage = "SignalR";
@ -252,7 +252,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
public async Task InvokeNonExistantClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path)
{
using (StartLog(out var loggerFactory, $"{nameof(InvokeNonExistantClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}"))
using (StartLog(out var loggerFactory, LogLevel.Trace, $"{nameof(InvokeNonExistantClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}"))
{
var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory);
var connection = new HubConnection(httpConnection, protocol, loggerFactory);
@ -292,7 +292,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
public async Task CanStreamClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path)
{
using (StartLog(out var loggerFactory, $"{nameof(CanStreamClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}"))
using (StartLog(out var loggerFactory, LogLevel.Trace, $"{nameof(CanStreamClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}"))
{
var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory);
var connection = new HubConnection(httpConnection, protocol, loggerFactory);

View File

@ -214,7 +214,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var httpHandler = new TestHttpMessageHandler();
var longPollResult = new TaskCompletionSource<HttpResponseMessage>();
httpHandler.OnLongPoll(cancellationToken => longPollResult.Task.OrTimeout());
httpHandler.OnLongPoll(cancellationToken =>
{
cancellationToken.Register(() =>
{
longPollResult.TrySetResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent));
});
return longPollResult.Task;
});
httpHandler.OnSocketSend((data, _) =>
{
@ -227,9 +234,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
async (connection, closed) =>
{
await connection.StartAsync().OrTimeout();
await Assert.ThrowsAsync<HttpRequestException>(() => connection.SendAsync(new byte[] { 0x42 }).OrTimeout());
longPollResult.TrySetResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent));
await connection.SendAsync(new byte[] { 0x42 }).OrTimeout();
// Wait for the connection to close, because the send failed.
await Assert.ThrowsAsync<HttpRequestException>(() => closed.OrTimeout());
@ -318,7 +323,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
async (connection, closed) =>
{
await connection.StartAsync().OrTimeout();
testTransport.Application.Writer.TryComplete(expected);
testTransport.Application.Output.Complete(expected);
var actual = await Assert.ThrowsAsync<Exception>(() => closed.OrTimeout());
Assert.Same(expected, actual);

View File

@ -35,7 +35,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}, receiveTcs);
await connection.StartAsync().OrTimeout();
Assert.Equal("42", await receiveTcs.Task.OrTimeout());
Assert.Contains("42", await receiveTcs.Task.OrTimeout());
});
}
@ -66,7 +66,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}, receiveTcs);
await connection.StartAsync().OrTimeout();
Assert.Equal("42", await receiveTcs.Task.OrTimeout());
Assert.Contains("42", await receiveTcs.Task.OrTimeout());
Assert.True(receivedRaised);
});
}
@ -98,7 +98,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}, receiveTcs);
await connection.StartAsync().OrTimeout();
Assert.Equal("42", await receiveTcs.Task.OrTimeout());
Assert.Contains("42", await receiveTcs.Task.OrTimeout());
Assert.True(receivedRaised);
});
}

View File

@ -92,13 +92,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}
[Fact]
public async Task CallerReceivesExceptionsFromSendAsync()
public async Task ExceptionOnSendAsyncClosesWithError()
{
var testHttpHandler = new TestHttpMessageHandler();
var longPollTcs = new TaskCompletionSource<HttpResponseMessage>();
var longPollTcs = new TaskCompletionSource<HttpResponseMessage>(TaskCreationOptions.RunContinuationsAsynchronously);
testHttpHandler.OnLongPoll(cancellationToken => longPollTcs.Task);
testHttpHandler.OnLongPoll(cancellationToken =>
{
cancellationToken.Register(() => longPollTcs.TrySetResult(null));
return longPollTcs.Task;
});
testHttpHandler.OnSocketSend((buf, cancellationToken) =>
{
@ -111,10 +116,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
await connection.StartAsync().OrTimeout();
var exception = await Assert.ThrowsAsync<HttpRequestException>(
async () => await connection.SendAsync(new byte[0]).OrTimeout());
await connection.SendAsync(new byte[] { 0 }).OrTimeout();
longPollTcs.TrySetResult(null);
var exception = await Assert.ThrowsAsync<HttpRequestException>(() => closed.OrTimeout());
});
}
}

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
@ -62,20 +63,20 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
await connection.StartAsync().OrTimeout();
// This will trigger the received callback
testTransport.Application.Writer.TryWrite(Array.Empty<byte>());
await testTransport.Application.Output.WriteAsync(new byte[] { 1 });
// Wait to hit the sync point. We are now blocking up the TaskQueue
await onReceived.WaitForSyncPoint().OrTimeout();
// Now we write something else and we want to test that the HttpConnection receive loop is still
// removing items from the channel even though OnReceived is blocked up.
testTransport.Application.Writer.TryWrite(Array.Empty<byte>());
await testTransport.Application.Output.WriteAsync(new byte[] { 1 });
// Now that we've written, we wait for WaitToReadAsync to return an INCOMPLETE task. It will do so
// once HttpConnection reads the message. We also use a CTS to timeout in case the loop is indeed blocked
var cts = new CancellationTokenSource();
cts.CancelAfter(TimeSpan.FromSeconds(5));
while (testTransport.Application.Reader.WaitToReadAsync().IsCompleted && !cts.IsCancellationRequested)
while (testTransport.Application.Input.WaitToReadAsync().IsCompleted && !cts.IsCancellationRequested)
{
// Yield to allow the HttpConnection to dequeue the message
await Task.Yield();
@ -109,7 +110,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
await connection.StartAsync().OrTimeout();
logger.LogInformation("Started connection");
testTransport.Application.Writer.TryWrite(Array.Empty<byte>());
await testTransport.Application.Output.WriteAsync(new byte[] { 1 });
await onReceived.WaitForSyncPoint().OrTimeout();
// Dispose should complete, even though the receive callbacks are completely blocked up.

View File

@ -251,10 +251,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
public ProtocolType Type => ProtocolType.Binary;
public bool TryParseMessages(ReadOnlySpan<byte> input, IInvocationBinder binder, out IList<HubMessage> messages)
public bool TryParseMessages(ReadOnlySpan<byte> input, IInvocationBinder binder, IList<HubMessage> messages)
{
messages = new List<HubMessage>();
ParseCalls += 1;
if (_error != null)
{

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Net;
using System.Net.Http;
using System.Text;
@ -12,7 +13,6 @@ using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Client.Tests;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Internal;
using Moq;
using Moq.Protected;
using Xunit;
@ -41,10 +41,8 @@ namespace Microsoft.AspNetCore.Client.Tests
try
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection());
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection());
transportActiveTask = longPollingTransport.Running;
@ -77,13 +75,14 @@ namespace Microsoft.AspNetCore.Client.Tests
var longPollingTransport = new LongPollingTransport(httpClient);
try
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = ChannelConnection.Create(connectionToTransport, transportToConnection);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection());
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection());
await longPollingTransport.Running.OrTimeout();
Assert.True(transportToConnection.Reader.Completion.IsCompleted);
Assert.True(pair.Transport.Input.TryRead(out var result));
Assert.True(result.IsCompleted);
pair.Transport.Input.AdvanceTo(result.Buffer.End);
}
finally
{
@ -130,17 +129,12 @@ namespace Microsoft.AspNetCore.Client.Tests
var longPollingTransport = new LongPollingTransport(httpClient);
try
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection());
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection());
var data = await transportToConnection.Reader.ReadAllAsync().OrTimeout();
var data = await pair.Transport.Input.ReadAllAsync().OrTimeout();
await longPollingTransport.Running.OrTimeout();
Assert.True(transportToConnection.Reader.Completion.IsCompleted);
Assert.Equal(2, data.Count);
Assert.Equal(Encoding.UTF8.GetBytes("Hello"), data[0]);
Assert.Equal(Encoding.UTF8.GetBytes("World"), data[1]);
Assert.Equal(Encoding.UTF8.GetBytes("HelloWorld"), data);
}
finally
{
@ -166,13 +160,19 @@ namespace Microsoft.AspNetCore.Client.Tests
var longPollingTransport = new LongPollingTransport(httpClient);
try
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection());
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection());
var exception =
await Assert.ThrowsAsync<HttpRequestException>(async () => await transportToConnection.Reader.Completion.OrTimeout());
await Assert.ThrowsAsync<HttpRequestException>(async () =>
{
async Task ReadAsync()
{
await pair.Transport.Input.ReadAsync();
}
await ReadAsync().OrTimeout();
});
Assert.Contains(" 500 ", exception.Message);
}
finally
@ -202,21 +202,14 @@ namespace Microsoft.AspNetCore.Client.Tests
var longPollingTransport = new LongPollingTransport(httpClient);
try
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection());
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection());
await connectionToTransport.Writer.WriteAsync(new SendMessage());
await pair.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello World"));
await Assert.ThrowsAsync<HttpRequestException>(async () => await longPollingTransport.Running.OrTimeout());
// The channel needs to be drained for the Completion task to be completed
while (transportToConnection.Reader.TryRead(out var message))
{
}
var exception = await Assert.ThrowsAsync<HttpRequestException>(async () => await transportToConnection.Reader.Completion);
var exception = await Assert.ThrowsAsync<HttpRequestException>(async () => await pair.Transport.Input.ReadAllAsync().OrTimeout());
Assert.Contains(" 500 ", exception.Message);
}
finally
@ -243,17 +236,14 @@ namespace Microsoft.AspNetCore.Client.Tests
var longPollingTransport = new LongPollingTransport(httpClient);
try
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection());
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection());
connectionToTransport.Writer.Complete();
pair.Transport.Output.Complete();
await longPollingTransport.Running.OrTimeout();
await longPollingTransport.Running.OrTimeout();
await connectionToTransport.Reader.Completion.OrTimeout();
await pair.Transport.Input.ReadAllAsync().OrTimeout();
}
finally
{
@ -292,32 +282,22 @@ namespace Microsoft.AspNetCore.Client.Tests
var longPollingTransport = new LongPollingTransport(httpClient);
try
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
// Start the transport
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection());
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection());
// Wait for the transport to finish
await longPollingTransport.Running.OrTimeout();
// Pull Messages out of the channel
var messages = new List<byte[]>();
while (await transportToConnection.Reader.WaitToReadAsync())
{
while (transportToConnection.Reader.TryRead(out var message))
{
messages.Add(message);
}
}
var message = await pair.Transport.Input.ReadAllAsync();
// Check the provided request
Assert.Equal(2, sentRequests.Count);
// Check the messages received
Assert.Single(messages);
Assert.Equal(message1Payload, messages[0]);
Assert.Equal(message1Payload, message);
}
finally
{
@ -350,24 +330,19 @@ namespace Microsoft.AspNetCore.Client.Tests
var longPollingTransport = new LongPollingTransport(httpClient);
try
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var tcs1 = new TaskCompletionSource<object>();
var tcs2 = new TaskCompletionSource<object>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
// Pre-queue some messages
await connectionToTransport.Writer.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("Hello"), tcs1)).OrTimeout();
await connectionToTransport.Writer.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("World"), tcs2)).OrTimeout();
await pair.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello"));
await pair.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("World"));
// Start the transport
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection());
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection());
connectionToTransport.Writer.Complete();
pair.Transport.Output.Complete();
await longPollingTransport.Running.OrTimeout();
await connectionToTransport.Reader.Completion.OrTimeout();
await pair.Transport.Input.ReadAllAsync();
Assert.Single(sentRequests);
Assert.Equal(new byte[] { (byte)'H', (byte)'e', (byte)'l', (byte)'l', (byte)'o', (byte)'W', (byte)'o', (byte)'r', (byte)'l', (byte)'d'
@ -400,12 +375,10 @@ namespace Microsoft.AspNetCore.Client.Tests
try
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
Assert.Null(longPollingTransport.Mode);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, transferMode, connection: new TestConnection());
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, transferMode, connection: new TestConnection());
Assert.Equal(transferMode, longPollingTransport.Mode);
}
finally
@ -466,10 +439,8 @@ namespace Microsoft.AspNetCore.Client.Tests
try
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection());
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection());
var completedTask = await Task.WhenAny(completionTcs.Task, longPollingTransport.Running).OrTimeout();
Assert.Equal(completionTcs.Task, completedTask);

View File

@ -3,6 +3,7 @@
using System;
using System.IO;
using System.IO.Pipelines;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text;
@ -10,10 +11,8 @@ using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Client.Tests;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Internal;
using Moq;
using Moq.Protected;
using Xunit;
@ -51,11 +50,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
var sseTransport = new ServerSentEventsTransport(httpClient);
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await sseTransport.StartAsync(
new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connection: Mock.Of<IConnection>()).OrTimeout();
new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text, connection: Mock.Of<IConnection>()).OrTimeout();
await eventStreamTcs.Task.OrTimeout();
await sseTransport.StopAsync().OrTimeout();
@ -102,15 +99,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
try
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await sseTransport.StartAsync(
new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connection: Mock.Of<IConnection>()).OrTimeout();
new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text, connection: Mock.Of<IConnection>()).OrTimeout();
transportActiveTask = sseTransport.Running;
Assert.False(transportActiveTask.IsCompleted);
var message = await transportToConnection.Reader.ReadAsync().AsTask().OrTimeout();
var message = await pair.Transport.Input.ReadSingleAsync().OrTimeout();
Assert.Equal("3:abc", Encoding.ASCII.GetString(message));
}
finally
@ -150,11 +146,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var sseTransport = new ServerSentEventsTransport(httpClient);
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await sseTransport.StartAsync(
new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connection: Mock.Of<IConnection>()).OrTimeout();
new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text, connection: Mock.Of<IConnection>()).OrTimeout();
var exception = await Assert.ThrowsAsync<FormatException>(() => sseTransport.Running.OrTimeout());
Assert.Equal("Incomplete message.", exception.Message);
@ -195,18 +189,15 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var sseTransport = new ServerSentEventsTransport(httpClient);
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await sseTransport.StartAsync(
new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connection: Mock.Of<IConnection>()).OrTimeout();
new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text, connection: Mock.Of<IConnection>()).OrTimeout();
await eventStreamTcs.Task;
var sendTcs = new TaskCompletionSource<object>();
Assert.True(connectionToTransport.Writer.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs)));
await pair.Transport.Output.WriteAsync(new byte[] { 0x42 });
var exception = await Assert.ThrowsAsync<HttpRequestException>(() => sendTcs.Task.OrTimeout());
var exception = await Assert.ThrowsAsync<HttpRequestException>(() => pair.Transport.Input.ReadAllAsync().OrTimeout());
Assert.Contains("500", exception.Message);
Assert.Same(exception, await Assert.ThrowsAsync<HttpRequestException>(() => sseTransport.Running.OrTimeout()));
@ -242,15 +233,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var sseTransport = new ServerSentEventsTransport(httpClient);
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await sseTransport.StartAsync(
new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connection: Mock.Of<IConnection>()).OrTimeout();
new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text, connection: Mock.Of<IConnection>()).OrTimeout();
await eventStreamTcs.Task.OrTimeout();
connectionToTransport.Writer.TryComplete(null);
pair.Transport.Output.Complete();
await sseTransport.Running.OrTimeout();
}
@ -272,13 +261,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var sseTransport = new ServerSentEventsTransport(httpClient);
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
await sseTransport.StartAsync(
new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connection: Mock.Of<IConnection>()).OrTimeout();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var message = await transportToConnection.Reader.ReadAsync().AsTask().OrTimeout();
await sseTransport.StartAsync(
new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text, connection: Mock.Of<IConnection>()).OrTimeout();
var message = await pair.Transport.Input.ReadSingleAsync().OrTimeout();
Assert.Equal("3:abc", Encoding.ASCII.GetString(message));
await sseTransport.Running.OrTimeout();
@ -302,11 +290,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
var sseTransport = new ServerSentEventsTransport(httpClient);
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
Assert.Null(sseTransport.Mode);
await sseTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, transferMode, connection: Mock.Of<IConnection>()).OrTimeout();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
await sseTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, transferMode, connection: Mock.Of<IConnection>()).OrTimeout();
Assert.Equal(TransferMode.Text, sseTransport.Mode);
await sseTransport.StopAsync().OrTimeout();
}
@ -327,9 +314,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
var sseTransport = new ServerSentEventsTransport(httpClient);
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var exception = await Assert.ThrowsAsync<ArgumentException>(() =>
sseTransport.StartAsync(new Uri("http://fakeuri.org"), null, TransferMode.Text | TransferMode.Binary, connection: Mock.Of<IConnection>()));

View File

@ -1,4 +1,5 @@
using System;
using System.IO.Pipelines;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
@ -12,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
private readonly Func<Task> _startHandler;
public TransferMode? Mode { get; }
public Channel<byte[], SendMessage> Application { get; private set; }
public IDuplexPipe Application { get; private set; }
public TestTransport(Func<Task> onTransportStop = null, Func<Task> onTransportStart = null, TransferMode transferMode = TransferMode.Text)
{
@ -21,7 +22,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
Mode = transferMode;
}
public Task StartAsync(Uri url, Channel<byte[], SendMessage> application, TransferMode requestedTransferMode, IConnection connection)
public Task StartAsync(Uri url, IDuplexPipe application, TransferMode requestedTransferMode, IConnection connection)
{
Application = application;
return _startHandler();
@ -30,7 +31,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
public async Task StopAsync()
{
await _stopHandler();
Application.Writer.TryComplete();
Application.Output.Complete();
}
}
}

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
using Xunit;
@ -22,11 +23,31 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Encoders
[MemberData(nameof(Payloads))]
public void VerifyEncode(string payload, string encoded)
{
var encodedMessage = Encoding.UTF8.GetBytes(encoded);
var decodedMessage = Encoding.UTF8.GetString(new Base64Encoder().Decode(encodedMessage).ToArray());
ReadOnlySpan<byte> encodedMessage = Encoding.UTF8.GetBytes(encoded);
var encoder = new Base64Encoder();
encoder.TryDecode(ref encodedMessage, out var data);
var decodedMessage = Encoding.UTF8.GetString(data.ToArray());
Assert.Equal(payload, decodedMessage);
}
[Fact]
public void CanParseMultipleMessages()
{
ReadOnlySpan<byte> data = Encoding.UTF8.GetBytes("28:QQpSDUMNCjtERUYxMjM0NTY3ODkw;4:QUJD;4:QUJD;");
var encoder = new Base64Encoder();
Assert.True(encoder.TryDecode(ref data, out var payload1));
Assert.True(encoder.TryDecode(ref data, out var payload2));
Assert.True(encoder.TryDecode(ref data, out var payload3));
Assert.False(encoder.TryDecode(ref data, out var payload4));
Assert.Equal(0, data.Length);
var payload1Value = Encoding.UTF8.GetString(payload1.ToArray());
var payload2Value = Encoding.UTF8.GetString(payload2.ToArray());
var payload3Value = Encoding.UTF8.GetString(payload3.ToArray());
Assert.Equal("A\nR\rC\r\n;DEF1234567890", payload1Value);
Assert.Equal("ABC", payload2Value);
Assert.Equal("ABC", payload3Value);
}
public static IEnumerable<object[]> Payloads =>
new object[][]
{

View File

@ -125,7 +125,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var binder = new TestBinder(expectedMessage);
var protocol = new JsonHubProtocol(Options.Create(protocolOptions));
protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, out var messages);
var messages = new List<HubMessage>();
protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, messages);
Assert.Equal(expectedMessage, messages[0], TestHubMessageEqualityComparer.Instance);
}
@ -174,7 +175,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var binder = new TestBinder(Array.Empty<Type>(), typeof(object));
var protocol = new JsonHubProtocol();
var ex = Assert.Throws<InvalidDataException>(() => protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, out var messages));
var messages = new List<HubMessage>();
var ex = Assert.Throws<InvalidDataException>(() => protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, messages));
Assert.Equal(expectedMessage, ex.Message);
}
@ -189,7 +191,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var binder = new TestBinder(paramTypes: new[] { typeof(int), typeof(string) }, returnType: typeof(bool));
var protocol = new JsonHubProtocol();
protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, out var messages);
var messages = new List<HubMessage>();
protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, messages);
var ex = Assert.Throws<InvalidDataException>(() => ((HubMethodInvocationMessage)messages[0]).Arguments);
Assert.Equal(expectedMessage, ex.Message);
}

View File

@ -349,9 +349,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
// Parse the input fully now.
bytes = Frame(bytes);
var protocol = new MessagePackHubProtocol();
Assert.True(protocol.TryParseMessages(bytes, new TestBinder(testData.Message), out var messages));
var messages = new List<HubMessage>();
Assert.True(protocol.TryParseMessages(bytes, new TestBinder(testData.Message), messages));
Assert.Equal(1, messages.Count);
Assert.Single(messages);
Assert.Equal(testData.Message, messages[0], TestHubMessageEqualityComparer.Instance);
}
@ -419,7 +420,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
{
var buffer = Frame(Pack(testData.Encoded));
var binder = new TestBinder(new[] { typeof(string) }, typeof(string));
var exception = Assert.Throws<FormatException>(() => _hubProtocol.TryParseMessages(buffer, binder, out var messages));
var messages = new List<HubMessage>();
var exception = Assert.Throws<FormatException>(() => _hubProtocol.TryParseMessages(buffer, binder, messages));
Assert.Equal(testData.ErrorMessage, exception.Message);
}
@ -447,7 +449,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
{
var buffer = Frame(Pack(testData.Encoded));
var binder = new TestBinder(new[] { typeof(string) }, typeof(string));
_hubProtocol.TryParseMessages(buffer, binder, out var messages);
var messages = new List<HubMessage>();
_hubProtocol.TryParseMessages(buffer, binder, messages);
var exception = Assert.Throws<FormatException>(() => ((HubMethodInvocationMessage)messages[0]).Arguments);
Assert.Equal(testData.ErrorMessage, exception.Message);
@ -458,7 +461,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
public void ParserDoesNotConsumePartialData(byte[] payload, int expectedMessagesCount)
{
var binder = new TestBinder(new[] { typeof(string) }, typeof(string));
var result = _hubProtocol.TryParseMessages(payload, binder, out var messages);
var messages = new List<HubMessage>();
var result = _hubProtocol.TryParseMessages(payload, binder, messages);
Assert.True(result || messages.Count == 0);
Assert.Equal(expectedMessagesCount, messages.Count);
}

View File

@ -0,0 +1,75 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
namespace System.IO.Pipelines
{
public static class PipeReaderExtensions
{
public static async Task<bool> WaitToReadAsync(this PipeReader pipeReader)
{
while (true)
{
var result = await pipeReader.ReadAsync();
try
{
if (!result.Buffer.IsEmpty)
{
return true;
}
if (result.IsCompleted)
{
return false;
}
}
finally
{
// Consume nothing, just wait for everything
pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End);
}
}
}
public static async Task<byte[]> ReadSingleAsync(this PipeReader pipeReader)
{
while (true)
{
var result = await pipeReader.ReadAsync();
try
{
return result.Buffer.ToArray();
}
finally
{
pipeReader.AdvanceTo(result.Buffer.End);
}
}
}
public static async Task<byte[]> ReadAllAsync(this PipeReader pipeReader)
{
while (true)
{
var result = await pipeReader.ReadAsync();
try
{
if (result.IsCompleted)
{
return result.Buffer.ToArray();
}
}
finally
{
// Consume nothing, just wait for everything
pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End);
}
}
}
}
}

View File

@ -260,7 +260,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
[MemberData(nameof(MessageSizesData))]
public async Task ConnectionCanSendAndReceiveDifferentMessageSizesWebSocketsTransport(string message)
{
using (StartLog(out var loggerFactory, testName: $"ConnectionCanSendAndReceiveDifferentMessageSizesWebSocketsTransport_{message.Length}"))
using (StartLog(out var loggerFactory, LogLevel.Trace, testName: $"ConnectionCanSendAndReceiveDifferentMessageSizesWebSocketsTransport_{message.Length}"))
{
var logger = loggerFactory.CreateLogger<EndToEndTests>();

View File

@ -2,11 +2,10 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Channels;
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Testing.xunit;
using Microsoft.Extensions.Logging.Testing;
using Moq;
@ -36,12 +35,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartLog(out var loggerFactory))
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory);
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection,
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), pair.Application,
TransferMode.Binary, connection: Mock.Of<IConnection>()).OrTimeout();
await webSocketsTransport.StopAsync().OrTimeout();
await webSocketsTransport.Running.OrTimeout();
@ -54,14 +50,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartLog(out var loggerFactory))
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory);
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection,
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), pair.Application,
TransferMode.Binary, connection: Mock.Of<IConnection>());
connectionToTransport.Writer.TryComplete();
pair.Transport.Output.Complete();
await webSocketsTransport.Running.OrTimeout(TimeSpan.FromSeconds(10));
}
}
@ -74,33 +67,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartLog(out var loggerFactory))
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory);
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, transferMode, connection: Mock.Of<IConnection>());
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), pair.Application, transferMode, connection: Mock.Of<IConnection>());
var sendTcs = new TaskCompletionSource<object>();
connectionToTransport.Writer.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs));
try
{
await sendTcs.Task;
}
catch (OperationCanceledException)
{
// Because the server and client are run in the same process there is a race where websocket.SendAsync
// can send a message but before returning be suspended allowing the server to run the EchoEndpoint and
// send a close frame which triggers a cancellation token on the client and cancels the websocket.SendAsync.
// Our solution to this is to just catch OperationCanceledException from the sent message if the race happens
// because we know the send went through, and its safe to check the response.
}
await pair.Transport.Output.WriteAsync(new byte[] { 0x42 });
// The echo endpoint closes the connection immediately after sending response which should stop the transport
await webSocketsTransport.Running.OrTimeout();
Assert.True(transportToConnection.Reader.TryRead(out var buffer));
Assert.Equal(new byte[] { 0x42 }, buffer);
Assert.True(pair.Transport.Input.TryRead(out var result));
Assert.Equal(new byte[] { 0x42 }, result.Buffer.ToArray());
pair.Transport.Input.AdvanceTo(result.Buffer.End);
}
}
@ -112,14 +90,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartLog(out var loggerFactory))
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory);
Assert.Null(webSocketsTransport.Mode);
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection,
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), pair.Application,
transferMode, connection: Mock.Of<IConnection>()).OrTimeout();
Assert.Equal(transferMode, webSocketsTransport.Mode);
@ -134,13 +109,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartLog(out var loggerFactory))
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory);
var exception = await Assert.ThrowsAsync<ArgumentException>(() =>
webSocketsTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text | TransferMode.Binary, connection: Mock.Of<IConnection>()));
webSocketsTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text | TransferMode.Binary, connection: Mock.Of<IConnection>()));
Assert.Contains("Invalid transfer mode.", exception.Message);
Assert.Equal("requestedTransferMode", exception.ParamName);