From 792745ad98f9c6314406ab9ce87e255ee0ba3e4a Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Mon, 13 Nov 2017 15:05:35 -0800 Subject: [PATCH] React to CoreFxLab packages (#998) --- Directory.Build.props | 5 + .../MessageParserBenchmark.cs | 4 +- build/dependencies.props | 18 +- .../EchoEndPoint.cs | 5 +- .../PersistentConnectionLifeTimeManager.cs | 2 +- .../SocialWeather/SocialWeatherEndPoint.cs | 6 +- .../EndPoints/MessagesEndPoint.cs | 8 +- samples/SocketsSample/Hubs/Streaming.cs | 10 +- .../HubConnection.cs | 6 +- .../HubConnectionExtensions.StreamAsync.cs | 38 +- .../HubConnectionExtensions.cs | 2 +- .../InvocationRequest.cs | 18 +- .../Internal/Encoders/Base64Encoder.cs | 2 +- .../LengthPrefixedTextMessageParser.cs | 10 +- .../Formatters/BinaryMessageFormatter.cs | 2 +- .../Formatters/BinaryMessageParser.cs | 2 +- .../Internal/Formatters/TextMessageParser.cs | 2 +- .../Internal/Protocol/IHubProtocol.cs | 2 +- .../Internal/Protocol/JsonHubProtocol.cs | 2 +- .../Protocol/MessagePackHubProtocol.cs | 2 +- .../Internal/Protocol/NegotiationProtocol.cs | 2 +- ...Microsoft.AspNetCore.SignalR.Common.csproj | 1 - .../HubConnectionContext.cs | 12 +- .../HubEndPoint.cs | 14 +- .../Internal/AsyncEnumeratorAdapters.cs | 82 +++- .../ChannelConnection.cs | 26 +- .../ChannelReaderExtensions.cs | 47 ++ .../ConnectionContext.cs | 2 +- .../Features/IConnectionTransportFeature.cs | 2 +- ...oft.AspNetCore.Sockets.Abstractions.csproj | 3 +- .../HttpConnection.cs | 8 +- .../ITransport.cs | 4 +- .../LongPollingTransport.cs | 10 +- ...soft.AspNetCore.Sockets.Client.Http.csproj | 2 +- .../SendUtils.cs | 8 +- .../ServerSentEventsMessageParser.cs | 2 +- .../ServerSentEventsTransport.cs | 14 +- .../WebSocketsTransport.cs | 18 +- .../HttpConnectionDispatcher.cs | 12 +- .../Transports/LongPollingTransport.cs | 8 +- .../ServerSentEventsMessageFormatter.cs | 2 +- .../Transports/ServerSentEventsTransport.cs | 8 +- .../Transports/WebSocketsTransport.cs | 14 +- .../Microsoft.AspNetCore.Sockets.Http.csproj | 2 +- .../ConnectionManager.cs | 2 +- .../DefaultConnectionContext.cs | 10 +- .../Microsoft.AspNetCore.Sockets.csproj | 2 +- test/Common/ChannelExtensions.cs | 5 +- test/Common/TestClient.cs | 14 +- .../HubConnectionTests.cs | 2 +- .../Hubs.cs | 20 +- .../HttpConnectionTests.cs | 28 +- .../HubConnectionProtocolTests.cs | 2 +- .../HubConnectionTests.cs | 2 +- .../LongPollingTransportTests.cs | 34 +- .../ServerSentEventsParserTests.cs | 16 +- .../ServerSentEventsTransportTests.cs | 19 +- .../TestConnection.cs | 17 +- .../LengthPrefixedTextMessageParserTests.cs | 10 +- .../Formatters/BinaryMessageFormatterTests.cs | 2 +- .../Formatters/BinaryMessageParserTests.cs | 12 +- .../Formatters/TextMessageParserTests.cs | 8 +- .../RedisHubLifetimeManagerTests.cs | 46 +- .../DefaultHubLifetimeManagerTests.cs | 33 +- .../EchoEndPoint.cs | 3 +- .../HubEndpointTests.cs | 20 +- .../DefaultHubProtocolResolverTests.cs | 8 +- .../WebSocketsTransportTests.cs | 10 +- .../ConnectionManagerTests.cs | 6 +- .../HttpConnectionDispatcherTests.cs | 20 +- .../LongPollingTests.cs | 18 +- .../MapEndPointTests.cs | 59 ++- .../Microsoft.AspNetCore.Sockets.Tests.csproj | 3 +- .../ServerSentEventsTests.cs | 18 +- .../TestWebSocketConnectionFeature.cs | 16 +- .../WebSocketsTests.cs | 418 ++++++++++-------- 76 files changed, 766 insertions(+), 566 deletions(-) create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/ChannelReaderExtensions.cs diff --git a/Directory.Build.props b/Directory.Build.props index b51ed60133..a391978b7d 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -11,6 +11,11 @@ true true true + latest + + + + diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/MessageParserBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/MessageParserBenchmark.cs index a9b99c63e1..0661272907 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/MessageParserBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/MessageParserBenchmark.cs @@ -9,8 +9,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks public class MessageParserBenchmark { private static readonly Random Random = new Random(); - private ReadOnlyBuffer _binaryInput; - private ReadOnlyBuffer _textInput; + private ReadOnlyMemory _binaryInput; + private ReadOnlyMemory _textInput; [Params(32, 64)] public int ChunkSize { get; set; } diff --git a/build/dependencies.props b/build/dependencies.props index 8608bcaf2e..ddbe310bd4 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -50,21 +50,23 @@ 2.1.0-preview1-27475 2.1.0-preview1-27475 2.0.0 + 2.6.0-beta2-62211-02 15.3.0 4.7.49 0.9.0-beta2 10.0.1 1.2.4 - 0.1.0-e170811-6 - 0.1.0-e170811-6 - 0.1.0-e170811-6 - 0.1.0-e170811-6 - 4.4.0-preview3-25519-03 - 4.4.0 + 0.1.0-alpha-002 + 0.1.0-alpha-002 + 0.1.0-alpha-002 + 0.1.0-alpha-002 + 4.5.0-preview1-25902-08 + 4.5.0-preview1-25902-08 3.1.1 4.3.0 - 4.4.0 - 0.1.0-e170811-6 + 4.5.0-preview1-25902-08 + 4.5.0-preview1-25902-08 + 4.4.0 2.3.0 2.3.0 diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs index e08b972bd1..433532b432 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs @@ -1,8 +1,9 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.Sockets; namespace Microsoft.AspNetCore.SignalR.Test.Server @@ -11,7 +12,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server { public async override Task OnConnectedAsync(ConnectionContext connection) { - await connection.Transport.Out.WriteAsync(await connection.Transport.In.ReadAsync()); + await connection.Transport.Writer.WriteAsync(await connection.Transport.Reader.ReadAsync()); } } } diff --git a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs index 87a6fbb9c3..f0af31544a 100644 --- a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs +++ b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs @@ -40,7 +40,7 @@ namespace SocialWeather var ms = new MemoryStream(); await formatter.WriteAsync(data, ms); - connection.Transport.Out.TryWrite(ms.ToArray()); + connection.Transport.Writer.TryWrite(ms.ToArray()); } } diff --git a/samples/SocialWeather/SocialWeatherEndPoint.cs b/samples/SocialWeather/SocialWeatherEndPoint.cs index e412cfafeb..17889ec1aa 100644 --- a/samples/SocialWeather/SocialWeatherEndPoint.cs +++ b/samples/SocialWeather/SocialWeatherEndPoint.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.IO; @@ -34,9 +34,9 @@ namespace SocialWeather var formatter = _formatterResolver.GetFormatter( (string)connection.Metadata["formatType"]); - while (await connection.Transport.In.WaitToReadAsync()) + while (await connection.Transport.Reader.WaitToReadAsync()) { - if (connection.Transport.In.TryRead(out var buffer)) + if (connection.Transport.Reader.TryRead(out var buffer)) { var stream = new MemoryStream(); await stream.WriteAsync(buffer, 0, buffer.Length); diff --git a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs index 5559e56518..a17cb4624a 100644 --- a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs +++ b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; @@ -20,9 +20,9 @@ namespace SocketsSample.EndPoints try { - while (await connection.Transport.In.WaitToReadAsync()) + while (await connection.Transport.Reader.WaitToReadAsync()) { - if (connection.Transport.In.TryRead(out var buffer)) + if (connection.Transport.Reader.TryRead(out var buffer)) { // We can avoid the copy here but we'll deal with that later var text = Encoding.UTF8.GetString(buffer); @@ -50,7 +50,7 @@ namespace SocketsSample.EndPoints foreach (var c in Connections) { - tasks.Add(c.Transport.Out.WriteAsync(payload)); + tasks.Add(c.Transport.Writer.WriteAsync(payload)); } return Task.WhenAll(tasks); diff --git a/samples/SocketsSample/Hubs/Streaming.cs b/samples/SocketsSample/Hubs/Streaming.cs index 63fa8b71e1..cee2c42cdf 100644 --- a/samples/SocketsSample/Hubs/Streaming.cs +++ b/samples/SocketsSample/Hubs/Streaming.cs @@ -1,7 +1,7 @@ using System; using System.Reactive.Linq; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.SignalR; namespace SocketsSample.Hubs @@ -15,7 +15,7 @@ namespace SocketsSample.Hubs .Take(count); } - public ReadableChannel ChannelCounter(int count, int delay) + public ChannelReader ChannelCounter(int count, int delay) { var channel = Channel.CreateUnbounded(); @@ -23,14 +23,14 @@ namespace SocketsSample.Hubs { for (var i = 0; i < count; i++) { - await channel.Out.WriteAsync(i); + await channel.Writer.WriteAsync(i); await Task.Delay(delay); } - channel.Out.TryComplete(); + channel.Writer.TryComplete(); }); - return channel.In; + return channel.Reader; } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index 5294bcd93d..9caa61a92b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -8,7 +8,7 @@ using System.Diagnostics; using System.IO; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.SignalR.Client.Internal; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Encoders; @@ -145,12 +145,12 @@ namespace Microsoft.AspNetCore.SignalR.Client return new Subscription(invocationHandler, invocationList); } - public async Task> StreamAsync(string methodName, Type returnType, object[] args, CancellationToken cancellationToken = default) + public async Task> StreamAsync(string methodName, Type returnType, object[] args, CancellationToken cancellationToken = default) { return await StreamAsyncCore(methodName, returnType, args, cancellationToken).ForceAsync(); } - private async Task> StreamAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken) + private async Task> StreamAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken) { if (!_startCalled) { diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.StreamAsync.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.StreamAsync.cs index b0821d97e9..81eec09974 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.StreamAsync.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.StreamAsync.cs @@ -1,71 +1,71 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; namespace Microsoft.AspNetCore.SignalR.Client { public static partial class HubConnectionExtensions { - public static Task> StreamAsync(this HubConnection hubConnection, string methodName, CancellationToken cancellationToken = default) + public static Task> StreamAsync(this HubConnection hubConnection, string methodName, CancellationToken cancellationToken = default) { return hubConnection.StreamAsync(methodName, Array.Empty(), cancellationToken); } - public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, CancellationToken cancellationToken = default) + public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, CancellationToken cancellationToken = default) { return hubConnection.StreamAsync(methodName, new object[] { arg1 }, cancellationToken); } - public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, CancellationToken cancellationToken = default) + public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, CancellationToken cancellationToken = default) { return hubConnection.StreamAsync(methodName, new object[] { arg1, arg2 }, cancellationToken); } - public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, CancellationToken cancellationToken = default) + public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, CancellationToken cancellationToken = default) { return hubConnection.StreamAsync(methodName, new object[] { arg1, arg2, arg3 }, cancellationToken); } - public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, CancellationToken cancellationToken = default) + public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, CancellationToken cancellationToken = default) { return hubConnection.StreamAsync(methodName, new object[] { arg1, arg2, arg3, arg4 }, cancellationToken); } - public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, object arg5, CancellationToken cancellationToken = default) + public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, object arg5, CancellationToken cancellationToken = default) { return hubConnection.StreamAsync(methodName, new object[] { arg1, arg2, arg3, arg4, arg5 }, cancellationToken); } - public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, object arg5, object arg6, CancellationToken cancellationToken = default) + public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, object arg5, object arg6, CancellationToken cancellationToken = default) { return hubConnection.StreamAsync(methodName, new object[] { arg1, arg2, arg3, arg4, arg5, arg6 }, cancellationToken); } - public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, object arg5, object arg6, object arg7, CancellationToken cancellationToken = default) + public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, object arg5, object arg6, object arg7, CancellationToken cancellationToken = default) { return hubConnection.StreamAsync(methodName, new object[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7 }, cancellationToken); } - public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, object arg5, object arg6, object arg7, object arg8, CancellationToken cancellationToken = default) + public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, object arg5, object arg6, object arg7, object arg8, CancellationToken cancellationToken = default) { return hubConnection.StreamAsync(methodName, new object[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 }, cancellationToken); } - public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, object arg5, object arg6, object arg7, object arg8, object arg9, CancellationToken cancellationToken = default) + public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, object arg5, object arg6, object arg7, object arg8, object arg9, CancellationToken cancellationToken = default) { return hubConnection.StreamAsync(methodName, new object[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9 }, cancellationToken); } - public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, object arg5, object arg6, object arg7, object arg8, object arg9, object arg10, CancellationToken cancellationToken = default) + public static Task> StreamAsync(this HubConnection hubConnection, string methodName, object arg1, object arg2, object arg3, object arg4, object arg5, object arg6, object arg7, object arg8, object arg9, object arg10, CancellationToken cancellationToken = default) { return hubConnection.StreamAsync(methodName, new object[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10 }, cancellationToken); } - public static async Task> StreamAsync(this HubConnection hubConnection, string methodName, object[] args, CancellationToken cancellationToken = default) + public static async Task> StreamAsync(this HubConnection hubConnection, string methodName, object[] args, CancellationToken cancellationToken = default) { if (hubConnection == null) { @@ -85,9 +85,9 @@ namespace Microsoft.AspNetCore.SignalR.Client { while (inputChannel.TryRead(out var item)) { - while (!outputChannel.Out.TryWrite((TResult)item)) + while (!outputChannel.Writer.TryWrite((TResult)item)) { - if (!await outputChannel.Out.WaitToWriteAsync()) + if (!await outputChannel.Writer.WaitToWriteAsync()) { // Failed to write to the output channel because it was closed. Nothing really we can do but abort here. return; @@ -101,18 +101,18 @@ namespace Microsoft.AspNetCore.SignalR.Client } catch (Exception ex) { - outputChannel.Out.TryComplete(ex); + outputChannel.Writer.TryComplete(ex); } finally { // This will safely no-op if the catch block above ran. - outputChannel.Out.TryComplete(); + outputChannel.Writer.TryComplete(); } } _ = RunChannel(); - return outputChannel.In; + return outputChannel.Reader; } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.cs index edf00f1f83..5588a97b4a 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.cs @@ -4,7 +4,7 @@ using System; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; namespace Microsoft.AspNetCore.SignalR.Client { diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs index 40afc4cf02..1018c22b98 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs @@ -4,7 +4,7 @@ using System; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.SignalR.Client.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.Extensions.Logging; @@ -43,7 +43,7 @@ namespace Microsoft.AspNetCore.SignalR.Client } public static InvocationRequest Stream(CancellationToken cancellationToken, Type resultType, string invocationId, - ILoggerFactory loggerFactory, HubConnection hubConnection, out ReadableChannel result) + ILoggerFactory loggerFactory, HubConnection hubConnection, out ChannelReader result) { var req = new Streaming(cancellationToken, resultType, invocationId, loggerFactory, hubConnection); result = req.Result; @@ -75,7 +75,7 @@ namespace Microsoft.AspNetCore.SignalR.Client { } - public ReadableChannel Result => _channel.In; + public ChannelReader Result => _channel.Reader; public override void Complete(CompletionMessage completionMessage) { @@ -83,7 +83,7 @@ namespace Microsoft.AspNetCore.SignalR.Client if (completionMessage.Result != null) { Logger.ReceivedUnexpectedComplete(InvocationId); - _channel.Out.TryComplete(new InvalidOperationException("Server provided a result in a completion response to a streamed invocation.")); + _channel.Writer.TryComplete(new InvalidOperationException("Server provided a result in a completion response to a streamed invocation.")); } if (!string.IsNullOrEmpty(completionMessage.Error)) @@ -92,22 +92,22 @@ namespace Microsoft.AspNetCore.SignalR.Client return; } - _channel.Out.TryComplete(); + _channel.Writer.TryComplete(); } public override void Fail(Exception exception) { Logger.InvocationFailed(InvocationId); - _channel.Out.TryComplete(exception); + _channel.Writer.TryComplete(exception); } public override async ValueTask StreamItem(object item) { try { - while (!_channel.Out.TryWrite(item)) + while (!_channel.Writer.TryWrite(item)) { - if (!await _channel.Out.WaitToWriteAsync()) + if (!await _channel.Writer.WaitToWriteAsync()) { return false; } @@ -122,7 +122,7 @@ namespace Microsoft.AspNetCore.SignalR.Client protected override void Cancel() { - _channel.Out.TryComplete(new OperationCanceledException("Invocation terminated")); + _channel.Writer.TryComplete(new OperationCanceledException("Invocation terminated")); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/Base64Encoder.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/Base64Encoder.cs index f5a8c044b0..881b05112d 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/Base64Encoder.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/Base64Encoder.cs @@ -11,7 +11,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Encoders { public byte[] Decode(byte[] payload) { - var buffer = new ReadOnlyBuffer(payload); + var buffer = new ReadOnlyMemory(payload); LengthPrefixedTextMessageParser.TryParseMessage(ref buffer, out var message); return Convert.FromBase64String(Encoding.UTF8.GetString(message.ToArray())); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/LengthPrefixedTextMessageParser.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/LengthPrefixedTextMessageParser.cs index 25419fbd98..686add4a15 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/LengthPrefixedTextMessageParser.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/LengthPrefixedTextMessageParser.cs @@ -14,20 +14,18 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Encoders /// Attempts to parse a message from the buffer. Returns 'false' if there is not enough data to complete a message. Throws an /// exception if there is a format error in the provided data. /// - public static bool TryParseMessage(ref ReadOnlyBuffer buffer, out ReadOnlyBuffer payload) + public static bool TryParseMessage(ref ReadOnlyMemory buffer, out ReadOnlyMemory payload) { - payload = default; - var span = buffer.Span; + payload = default(ReadOnlyMemory); - if (!TryReadLength(span, out var index, out var length)) + if (!TryReadLength(buffer.Span, out var index, out var length)) { return false; } var remaining = buffer.Slice(index); - span = remaining.Span; - if (!TryReadDelimiter(span, LengthPrefixedTextMessageWriter.FieldDelimiter, "length")) + if (!TryReadDelimiter(remaining.Span, LengthPrefixedTextMessageWriter.FieldDelimiter, "length")) { return false; } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs index 8eb5936e92..44115592f6 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs @@ -33,7 +33,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters var buffer = ArrayPool.Shared.Rent(lenNumBytes + payload.Length); var bufferSpan = buffer.AsSpan(); - new Span(lenBuffer, lenNumBytes).CopyTo(bufferSpan); + new ReadOnlySpan(lenBuffer, lenNumBytes).CopyTo(bufferSpan); bufferSpan = bufferSpan.Slice(lenNumBytes); payload.CopyTo(bufferSpan); output.Write(buffer, 0, lenNumBytes + payload.Length); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageParser.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageParser.cs index 1835fa34ae..4889ea33f4 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageParser.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageParser.cs @@ -10,7 +10,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters private static int[] _numBitsToShift = new[] { 0, 7, 14, 21, 28 }; private const int MaxLengthPrefixSize = 5; - public static bool TryParseMessage(ref ReadOnlyBuffer buffer, out ReadOnlyBuffer payload) + public static bool TryParseMessage(ref ReadOnlyMemory buffer, out ReadOnlyMemory payload) { payload = default; diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageParser.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageParser.cs index 1a1c55bef4..fac697290e 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageParser.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageParser.cs @@ -7,7 +7,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters { public static class TextMessageParser { - public static bool TryParseMessage(ref ReadOnlyBuffer buffer, out ReadOnlyBuffer payload) + public static bool TryParseMessage(ref ReadOnlyMemory buffer, out ReadOnlyMemory payload) { payload = default; diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs index 4ccb27e0c2..c02cea455b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs @@ -13,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol ProtocolType Type { get; } - bool TryParseMessages(ReadOnlyBuffer input, IInvocationBinder binder, out IList messages); + bool TryParseMessages(ReadOnlyMemory input, IInvocationBinder binder, out IList messages); void WriteMessage(HubMessage message, Stream output); } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs index 9624302aac..4efa8a3e6b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs @@ -61,7 +61,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol public ProtocolType Type => ProtocolType.Text; - public bool TryParseMessages(ReadOnlyBuffer input, IInvocationBinder binder, out IList messages) + public bool TryParseMessages(ReadOnlyMemory input, IInvocationBinder binder, out IList messages) { messages = new List(); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs index 0640137930..30c85781c7 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs @@ -38,7 +38,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol _serializationContext = serializationContext; } - public bool TryParseMessages(ReadOnlyBuffer input, IInvocationBinder binder, out IList messages) + public bool TryParseMessages(ReadOnlyMemory input, IInvocationBinder binder, out IList messages) { messages = new List(); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationProtocol.cs index 7f6abe5b2c..c886a10ec2 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationProtocol.cs @@ -29,7 +29,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol } } - public static bool TryParseMessage(ReadOnlyBuffer input, out NegotiationMessage negotiationMessage) + public static bool TryParseMessage(ReadOnlyMemory input, out NegotiationMessage negotiationMessage) { if (!TextMessageParser.TryParseMessage(ref input, out var payload)) { diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj b/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj index 74d6d7217a..0ab94fe4e0 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj +++ b/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj @@ -10,7 +10,6 @@ - diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 56f118dd49..61e9b3e0f2 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -8,7 +8,7 @@ using System.Runtime.ExceptionServices; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.SignalR.Features; using Microsoft.AspNetCore.SignalR.Internal; @@ -22,12 +22,12 @@ namespace Microsoft.AspNetCore.SignalR { private static Action _abortedCallback = AbortConnection; - private readonly WritableChannel _output; + private readonly ChannelWriter _output; private readonly ConnectionContext _connectionContext; private readonly CancellationTokenSource _connectionAbortedTokenSource = new CancellationTokenSource(); private readonly TaskCompletionSource _abortCompletedTcs = new TaskCompletionSource(); - public HubConnectionContext(WritableChannel output, ConnectionContext connectionContext) + public HubConnectionContext(ChannelWriter output, ConnectionContext connectionContext) { _output = output; _connectionContext = connectionContext; @@ -37,7 +37,7 @@ namespace Microsoft.AspNetCore.SignalR private IHubFeature HubFeature => Features.Get(); // Used by the HubEndPoint only - internal ReadableChannel Input => _connectionContext.Transport; + internal ChannelReader Input => _connectionContext.Transport; internal ExceptionDispatchInfo AbortException { get; private set; } @@ -53,7 +53,7 @@ namespace Microsoft.AspNetCore.SignalR public virtual HubProtocolReaderWriter ProtocolReaderWriter { get; set; } - public virtual WritableChannel Output => _output; + public virtual ChannelWriter Output => _output; // Currently used only for streaming methods internal ConcurrentDictionary ActiveRequestCancellationSources { get; } = new ConcurrentDictionary(); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index bdb32153b0..f6c572ee37 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -9,7 +9,7 @@ using System.Reflection; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.SignalR.Core; using Microsoft.AspNetCore.SignalR.Core.Internal; @@ -84,14 +84,14 @@ namespace Microsoft.AspNetCore.SignalR { try { - while (await output.In.WaitToReadAsync()) + while (await output.Reader.WaitToReadAsync()) { - while (output.In.TryRead(out var hubMessage)) + while (output.Reader.TryRead(out var hubMessage)) { var buffer = protocolReaderWriter.WriteMessage(hubMessage); - while (await connection.Transport.Out.WaitToWriteAsync()) + while (await connection.Transport.Writer.WaitToWriteAsync()) { - if (connection.Transport.Out.TryWrite(buffer)) + if (connection.Transport.Writer.TryWrite(buffer)) { break; } @@ -117,7 +117,7 @@ namespace Microsoft.AspNetCore.SignalR await _lifetimeManager.OnDisconnectedAsync(connectionContext); // Nothing should be writing to the HubConnectionContext - output.Out.TryComplete(); + output.Writer.TryComplete(); // This should unwind once we complete the output await writingOutputTask; @@ -461,7 +461,7 @@ namespace Microsoft.AspNetCore.SignalR private static bool IsChannel(Type type, out Type payloadType) { - var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ReadableChannel<>)); + var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>)); if (channelType == null) { payloadType = null; diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs index f066e18048..870b06eaf2 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs @@ -6,7 +6,7 @@ using System.Linq; using System.Reflection; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; namespace Microsoft.AspNetCore.SignalR.Internal { @@ -21,6 +21,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal .GetRuntimeMethods() .Single(m => m.Name.Equals(nameof(FromObservable)) && m.IsGenericMethod); + private static readonly MethodInfo _getAsyncEnumeratorMethod = typeof(AsyncEnumeratorAdapters) + .GetRuntimeMethods() + .Single(m => m.Name.Equals(nameof(GetAsyncEnumerator)) && m.IsGenericMethod); + public static IAsyncEnumerator FromObservable(object observable, Type observableInterface, CancellationToken cancellationToken) { // TODO: Cache expressions by observable.GetType()? @@ -34,20 +38,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal // TODO: Allow bounding and optimizations? var channel = Channel.CreateUnbounded(); - var subscription = observable.Subscribe(new ChannelObserver(channel.Out, cancellationToken)); + var subscription = observable.Subscribe(new ChannelObserver(channel.Writer, cancellationToken)); // Dispose the subscription when the token is cancelled cancellationToken.Register(state => ((IDisposable)state).Dispose(), subscription); - return channel.In.GetAsyncEnumerator(cancellationToken); + return GetAsyncEnumerator(channel.Reader, cancellationToken); } public static IAsyncEnumerator FromChannel(object readableChannelOfT, Type payloadType, CancellationToken cancellationToken) { - var enumerator = readableChannelOfT - .GetType() - .GetRuntimeMethod("GetAsyncEnumerator", new[] { typeof(CancellationToken) }) - .Invoke(readableChannelOfT, new object[] { cancellationToken }); + var enumerator = _getAsyncEnumeratorMethod + .MakeGenericMethod(payloadType) + .Invoke(null, new object[] { readableChannelOfT, cancellationToken }); if (payloadType.IsValueType) { @@ -68,10 +71,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal private class ChannelObserver : IObserver { - private WritableChannel _output; + private ChannelWriter _output; private CancellationToken _cancellationToken; - public ChannelObserver(WritableChannel output, CancellationToken cancellationToken) + public ChannelObserver(ChannelWriter output, CancellationToken cancellationToken) { _output = output; _cancellationToken = cancellationToken; @@ -125,5 +128,66 @@ namespace Microsoft.AspNetCore.SignalR.Internal public object Current => _input.Current; public Task MoveNextAsync() => _input.MoveNextAsync(); } + + public static IAsyncEnumerator GetAsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken = default(CancellationToken)) + { + return new AsyncEnumerator(channel, cancellationToken); + } + + /// Provides an async enumerator for the data in a channel. + internal class AsyncEnumerator : IAsyncEnumerator + { + /// The channel being enumerated. + private readonly ChannelReader _channel; + /// Cancellation token used to cancel the enumeration. + private readonly CancellationToken _cancellationToken; + /// The current element of the enumeration. + private T _current; + + internal AsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken) + { + _channel = channel; + _cancellationToken = cancellationToken; + } + + public T Current => _current; + + public Task MoveNextAsync() + { + ValueTask result = _channel.ReadAsync(_cancellationToken); + + if (result.IsCompletedSuccessfully) + { + _current = result.Result; + return Task.FromResult(true); + } + + return result.AsTask().ContinueWith((t, s) => + { + var thisRef = (AsyncEnumerator)s; + if (t.IsFaulted && t.Exception.InnerException is ChannelClosedException cce && cce.InnerException == null) + { + return false; + } + thisRef._current = t.GetAwaiter().GetResult(); + return true; + }, this, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.NotOnCanceled, TaskScheduler.Default); + } + } + } + + /// Represents an enumerator accessed asynchronously. + /// Specifies the type of the data enumerated. + internal interface IAsyncEnumerator + { + /// Asynchronously move the enumerator to the next element. + /// + /// A task that returns true if the enumerator was successfully advanced to the next item, + /// or false if no more data was available in the collection. + /// + Task MoveNextAsync(); + + /// Gets the current element being enumerated. + T Current { get; } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ChannelConnection.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ChannelConnection.cs index 101d988e5e..5065cb9d3d 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/ChannelConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/ChannelConnection.cs @@ -1,8 +1,8 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; namespace Microsoft.AspNetCore.Sockets.Internal { @@ -24,20 +24,19 @@ namespace Microsoft.AspNetCore.Sockets.Internal public Channel Input { get; } public Channel Output { get; } - public override ReadableChannel In => Input; - - public override WritableChannel Out => Output; - public ChannelConnection(Channel input, Channel output) { + Reader = input.Reader; Input = input; + + Writer = output.Writer; Output = output; } public void Dispose() { - Input.Out.TryComplete(); - Output.Out.TryComplete(); + Input.Writer.TryComplete(); + Output.Writer.TryComplete(); } } @@ -46,20 +45,19 @@ namespace Microsoft.AspNetCore.Sockets.Internal public Channel Input { get; } public Channel Output { get; } - public override ReadableChannel In => Input; - - public override WritableChannel Out => Output; - public ChannelConnection(Channel input, Channel output) { + Reader = input.Reader; Input = input; + + Writer = output.Writer; Output = output; } public void Dispose() { - Input.Out.TryComplete(); - Output.Out.TryComplete(); + Input.Writer.TryComplete(); + Output.Writer.TryComplete(); } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ChannelReaderExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ChannelReaderExtensions.cs new file mode 100644 index 0000000000..e437a1a594 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/ChannelReaderExtensions.cs @@ -0,0 +1,47 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + public static class ChannelReaderExtensions + { + /// Asynchronously reads an item from the channel. + /// The channel + /// A used to cancel the read operation. + /// A that represents the asynchronous read operation. + public static ValueTask ReadAsync(this ChannelReader channel, CancellationToken cancellationToken = default) + { + try + { + return + cancellationToken.IsCancellationRequested + ? new ValueTask(Task.FromCanceled(cancellationToken)) + : channel.TryRead(out T item) + ? new ValueTask(item) + : ReadAsyncCore(cancellationToken); + } + catch (Exception e) + { + return new ValueTask(Task.FromException(e)); + } + + async ValueTask ReadAsyncCore(CancellationToken ct) + { + while (await channel.WaitToReadAsync(ct).ConfigureAwait(false)) + { + if (channel.TryRead(out T item)) + { + return item; + } + } + + throw new ChannelClosedException(); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs index 7fa7278130..8f4c799a16 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs @@ -2,7 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Sockets diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs index d29718fa31..e851b49bcc 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs @@ -1,7 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.Threading.Tasks.Channels; +using System.Threading.Channels; namespace Microsoft.AspNetCore.Sockets.Features { diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj b/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj index c2dba911f3..56b275ea06 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj @@ -7,7 +7,8 @@ - + + diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index c2c27f5fe2..1a4776ec45 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -8,7 +8,7 @@ using System.IO; using System.Net.Http; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Client.Internal; @@ -39,8 +39,8 @@ namespace Microsoft.AspNetCore.Sockets.Client private readonly ITransportFactory _transportFactory; private string _connectionId; private readonly TimeSpan _eventQueueDrainTimeout = TimeSpan.FromSeconds(5); - private ReadableChannel Input => _transportChannel.In; - private WritableChannel Output => _transportChannel.Out; + private ChannelReader Input => _transportChannel.Input; + private ChannelWriter Output => _transportChannel.Output; private readonly List _callbacks = new List(); private readonly TransportType _requestedTransportType = TransportType.All; diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs index 7591183e33..784400db85 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs @@ -1,9 +1,9 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; namespace Microsoft.AspNetCore.Sockets.Client { diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs index 9229219a52..a059a81230 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -6,7 +6,7 @@ using System.Net; using System.Net.Http; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.Extensions.Logging; @@ -59,7 +59,7 @@ namespace Microsoft.AspNetCore.Sockets.Client Running = Task.WhenAll(_sender, _poller).ContinueWith(t => { _logger.TransportStopped(_connectionId, t.Exception?.InnerException); - _application.Out.TryComplete(t.IsFaulted ? t.Exception.InnerException : null); + _application.Writer.TryComplete(t.IsFaulted ? t.Exception.InnerException : null); return t; }).Unwrap(); @@ -123,9 +123,9 @@ namespace Microsoft.AspNetCore.Sockets.Client var payload = await response.Content.ReadAsByteArrayAsync(); if (payload.Length > 0) { - while (!_application.Out.TryWrite(payload)) + while (!_application.Writer.TryWrite(payload)) { - if (cancellationToken.IsCancellationRequested || !await _application.Out.WaitToWriteAsync(cancellationToken)) + if (cancellationToken.IsCancellationRequested || !await _application.Writer.WaitToWriteAsync(cancellationToken)) { return; } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Microsoft.AspNetCore.Sockets.Client.Http.csproj b/src/Microsoft.AspNetCore.Sockets.Client.Http/Microsoft.AspNetCore.Sockets.Client.Http.csproj index d09d5db9a5..3a4980df1f 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Microsoft.AspNetCore.Sockets.Client.Http.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Microsoft.AspNetCore.Sockets.Client.Http.csproj @@ -22,7 +22,7 @@ - + diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs index a95c59a353..af05b12060 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -7,7 +7,7 @@ using System.IO; using System.Net.Http; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.Extensions.Logging; @@ -23,11 +23,11 @@ namespace Microsoft.AspNetCore.Sockets.Client IList messages = null; try { - while (await application.In.WaitToReadAsync(transportCts.Token)) + while (await application.Reader.WaitToReadAsync(transportCts.Token)) { // Grab as many messages as we can from the channel messages = new List(); - while (!transportCts.IsCancellationRequested && application.In.TryRead(out SendMessage message)) + while (!transportCts.IsCancellationRequested && application.Reader.TryRead(out SendMessage message)) { messages.Add(message); } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsMessageParser.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsMessageParser.cs index bc4713320d..4ad011babc 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsMessageParser.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsMessageParser.cs @@ -146,7 +146,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private Span ConvertBufferToSpan(ReadableBuffer buffer) + private ReadOnlySpan ConvertBufferToSpan(ReadableBuffer buffer) { if (buffer.IsSingleSpan) { diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs index 9018659ddd..62e92d12c2 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs @@ -1,13 +1,14 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; using System.IO.Pipelines; using System.Net.Http; using System.Net.Http.Headers; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Microsoft.Extensions.Logging; @@ -17,6 +18,7 @@ namespace Microsoft.AspNetCore.Sockets.Client { public class ServerSentEventsTransport : ITransport { + private static readonly MemoryPool _memoryPool = new MemoryPool(); private readonly HttpClient _httpClient; private readonly ILogger _logger; private readonly CancellationTokenSource _transportCts = new CancellationTokenSource(); @@ -64,7 +66,7 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.TransportStopped(_connectionId, t.Exception?.InnerException); - _application.Out.TryComplete(t.IsFaulted ? t.Exception.InnerException : null); + _application.Writer.TryComplete(t.IsFaulted ? t.Exception.InnerException : null); return t; }).Unwrap(); @@ -80,7 +82,7 @@ namespace Microsoft.AspNetCore.Sockets.Client var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); var stream = await response.Content.ReadAsStreamAsync(); - var pipelineReader = stream.AsPipelineReader(cancellationToken); + var pipelineReader = StreamPipeConnection.CreateReader(new PipeOptions(_memoryPool), stream); var readCancellationRegistration = cancellationToken.Register( reader => ((IPipeReader)reader).CancelPendingRead(), pipelineReader); try @@ -105,7 +107,7 @@ namespace Microsoft.AspNetCore.Sockets.Client switch (parseResult) { case ServerSentEventsMessageParser.ParseResult.Completed: - _application.Out.TryWrite(buffer); + _application.Writer.TryWrite(buffer); _parser.Reset(); break; case ServerSentEventsMessageParser.ParseResult.Incomplete: @@ -139,7 +141,7 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.TransportStopping(_connectionId); _transportCts.Cancel(); - _application.Out.TryComplete(); + _application.Writer.TryComplete(); try { diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs index 2249c44bc9..5b8874b6c5 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -7,7 +7,7 @@ using System.Diagnostics; using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -70,8 +70,8 @@ namespace Microsoft.AspNetCore.Sockets.Client { _webSocket.Dispose(); _logger.TransportStopped(_connectionId, t.Exception?.InnerException); - _application.Out.TryComplete(t.IsFaulted ? t.Exception.InnerException : null); - return t; + _application.Writer.TryComplete(t.IsFaulted ? t.Exception.InnerException : null); + return t; }).Unwrap(); } @@ -97,7 +97,7 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.WebSocketClosed(_connectionId, receiveResult.CloseStatus); - _application.Out.Complete( + _application.Writer.Complete( receiveResult.CloseStatus == WebSocketCloseStatus.NormalClosure ? null : new InvalidOperationException( @@ -135,9 +135,9 @@ namespace Microsoft.AspNetCore.Sockets.Client if (!_transportCts.Token.IsCancellationRequested) { _logger.MessageToApp(_connectionId, messageBuffer.Length); - while (await _application.Out.WaitToWriteAsync(_transportCts.Token)) + while (await _application.Writer.WaitToWriteAsync(_transportCts.Token)) { - if (_application.Out.TryWrite(messageBuffer)) + if (_application.Writer.TryWrite(messageBuffer)) { incomingMessage.Clear(); break; @@ -173,9 +173,9 @@ namespace Microsoft.AspNetCore.Sockets.Client try { - while (await _application.In.WaitToReadAsync(_transportCts.Token)) + while (await _application.Reader.WaitToReadAsync(_transportCts.Token)) { - while (_application.In.TryRead(out SendMessage message)) + while (_application.Reader.TryRead(out SendMessage message)) { try { diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index 1bcca138ef..996605a456 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -93,7 +93,7 @@ namespace Microsoft.AspNetCore.Sockets connection.TransportCapabilities = TransferMode.Text; // We only need to provide the Input channel since writing to the application is handled through /send. - var sse = new ServerSentEventsTransport(connection.Application.In, connection.ConnectionId, _loggerFactory); + var sse = new ServerSentEventsTransport(connection.Application.Reader, connection.ConnectionId, _loggerFactory); await DoPersistentConnection(socketDelegate, sse, context, connection); } @@ -194,7 +194,7 @@ namespace Microsoft.AspNetCore.Sockets context.Response.RegisterForDispose(timeoutSource); context.Response.RegisterForDispose(tokenSource); - var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.In, connection.ConnectionId, _loggerFactory); + var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Reader, connection.ConnectionId, _loggerFactory); // Start the transport connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token); @@ -215,7 +215,7 @@ namespace Microsoft.AspNetCore.Sockets if (resultTask == connection.ApplicationTask) { // Complete the transport (notifying it of the application error if there is one) - connection.Transport.Out.TryComplete(connection.ApplicationTask.Exception); + connection.Transport.Writer.TryComplete(connection.ApplicationTask.Exception); // Wait for the transport to run await connection.TransportTask; @@ -408,9 +408,9 @@ namespace Microsoft.AspNetCore.Sockets } _logger.ReceivedBytes(connection.ConnectionId, buffer.Length); - while (!connection.Application.Out.TryWrite(buffer)) + while (!connection.Application.Writer.TryWrite(buffer)) { - if (!await connection.Application.Out.WaitToWriteAsync()) + if (!await connection.Application.Writer.WaitToWriteAsync()) { return; } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/LongPollingTransport.cs index 033efc6fe2..a9ae71af2a 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/LongPollingTransport.cs @@ -1,11 +1,11 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; @@ -13,12 +13,12 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports { public class LongPollingTransport : IHttpTransport { - private readonly ReadableChannel _application; + private readonly ChannelReader _application; private readonly ILogger _logger; private readonly CancellationToken _timeoutToken; private readonly string _connectionId; - public LongPollingTransport(CancellationToken timeoutToken, ReadableChannel application, string connectionId, ILoggerFactory loggerFactory) + public LongPollingTransport(CancellationToken timeoutToken, ChannelReader application, string connectionId, ILoggerFactory loggerFactory) { _timeoutToken = timeoutToken; _application = application; diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsMessageFormatter.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsMessageFormatter.cs index 5a04cc76ea..21b079c263 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsMessageFormatter.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsMessageFormatter.cs @@ -65,7 +65,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters if (nextSliceStart >= payload.Length) { - payload = Span.Empty; + payload = ReadOnlySpan.Empty; } else { diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsTransport.cs index 3ff1c1b756..19bda2e390 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsTransport.cs @@ -1,11 +1,11 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.IO; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Sockets.Internal.Formatters; @@ -15,11 +15,11 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports { public class ServerSentEventsTransport : IHttpTransport { - private readonly ReadableChannel _application; + private readonly ChannelReader _application; private readonly string _connectionId; private readonly ILogger _logger; - public ServerSentEventsTransport(ReadableChannel application, string connectionId, ILoggerFactory loggerFactory) + public ServerSentEventsTransport(ChannelReader application, string connectionId, ILoggerFactory loggerFactory) { _application = application; _connectionId = connectionId; diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs index 5adc0ab245..97756e3613 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -7,7 +7,7 @@ using System.Diagnostics; using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; @@ -87,7 +87,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports } // We're done writing - _application.Out.TryComplete(); + _application.Writer.TryComplete(); await socket.CloseOutputAsync(failed ? WebSocketCloseStatus.InternalServerError : WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); @@ -160,9 +160,9 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports } _logger.MessageToApplication(_connection.ConnectionId, messageBuffer.Length); - while (await _application.Out.WaitToWriteAsync()) + while (await _application.Writer.WaitToWriteAsync()) { - if (_application.Out.TryWrite(messageBuffer)) + if (_application.Writer.TryWrite(messageBuffer)) { incomingMessage.Clear(); break; @@ -173,10 +173,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports private async Task StartSending(WebSocket ws) { - while (await _application.In.WaitToReadAsync()) + while (await _application.Reader.WaitToReadAsync()) { // Get a frame from the application - while (_application.In.TryRead(out var buffer)) + while (_application.Reader.TryRead(out var buffer)) { if (buffer.Length > 0) { diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj b/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj index 1f7ff58d7a..7cbc3f60f8 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj @@ -16,7 +16,7 @@ - + diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index 0d23bc7161..9baa558045 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -9,7 +9,7 @@ using System.IO; using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.Logging; diff --git a/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs index 3acb6bb0c1..990827767b 100644 --- a/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs @@ -6,7 +6,7 @@ using System.Collections.Generic; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Sockets.Features; @@ -86,21 +86,21 @@ namespace Microsoft.AspNetCore.Sockets // If the application task is faulted, propagate the error to the transport if (ApplicationTask?.IsFaulted == true) { - Transport.Out.TryComplete(ApplicationTask.Exception.InnerException); + Transport.Writer.TryComplete(ApplicationTask.Exception.InnerException); } else { - Transport.Out.TryComplete(); + Transport.Writer.TryComplete(); } // If the transport task is faulted, propagate the error to the application if (TransportTask?.IsFaulted == true) { - Application.Out.TryComplete(TransportTask.Exception.InnerException); + Application.Writer.TryComplete(TransportTask.Exception.InnerException); } else { - Application.Out.TryComplete(); + Application.Writer.TryComplete(); } var applicationTask = ApplicationTask ?? Task.CompletedTask; diff --git a/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj b/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj index 0de655a731..27dd5e9574 100644 --- a/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj +++ b/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj @@ -13,7 +13,7 @@ - + diff --git a/test/Common/ChannelExtensions.cs b/test/Common/ChannelExtensions.cs index 2502886317..fd03225379 100644 --- a/test/Common/ChannelExtensions.cs +++ b/test/Common/ChannelExtensions.cs @@ -1,10 +1,11 @@ using System.Collections.Generic; +using System.Threading.Tasks; -namespace System.Threading.Tasks.Channels +namespace System.Threading.Channels { internal static class ChannelExtensions { - public static async Task> ReadAllAsync(this ReadableChannel channel) + public static async Task> ReadAllAsync(this ChannelReader channel) { var list = new List(); while (await channel.WaitToReadAsync()) diff --git a/test/Common/TestClient.cs b/test/Common/TestClient.cs index a4d903120f..42a7f95ba3 100644 --- a/test/Common/TestClient.cs +++ b/test/Common/TestClient.cs @@ -7,7 +7,7 @@ using System.IO; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Encoders; using Microsoft.AspNetCore.SignalR.Internal.Protocol; @@ -32,7 +32,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, bool addClaimId = false) { - var options = new ChannelOptimizations { AllowSynchronousContinuations = synchronousCallbacks }; + var options = new UnboundedChannelOptions { AllowSynchronousContinuations = synchronousCallbacks }; var transportToApplication = Channel.CreateUnbounded(options); var applicationToTransport = Channel.CreateUnbounded(options); @@ -60,7 +60,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var memoryStream = new MemoryStream()) { NegotiationProtocol.WriteMessage(new NegotiationMessage(protocol.Name), memoryStream); - Application.Out.TryWrite(memoryStream.ToArray()); + Application.Writer.TryWrite(memoryStream.ToArray()); } } @@ -149,7 +149,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public async Task SendHubMessageAsync(HubMessage message) { var payload = _protocolReaderWriter.WriteMessage(message); - await Application.Out.WriteAsync(payload); + await Application.Writer.WriteAsync(payload); return message.InvocationId; } @@ -161,7 +161,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests if (message == null) { - if (!await Application.In.WaitToReadAsync()) + if (!await Application.Reader.WaitToReadAsync()) { return null; } @@ -175,7 +175,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public HubMessage TryRead() { - if (Application.In.TryRead(out var buffer) && + if (Application.Reader.TryRead(out var buffer) && _protocolReaderWriter.ReadMessages(buffer, _invocationBinder, out var messages)) { return messages[0]; @@ -208,4 +208,4 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } } -} \ No newline at end of file +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 8e3c2bfc5e..e110cc26a0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -5,7 +5,7 @@ using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs index 02c2a87535..37b3e44dda 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs @@ -1,11 +1,11 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Linq; using System.Reactive.Linq; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { @@ -17,9 +17,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests public IObservable Stream(int count) => TestHubMethodsImpl.Stream(count); - public ReadableChannel StreamException() => TestHubMethodsImpl.StreamException(); + public ChannelReader StreamException() => TestHubMethodsImpl.StreamException(); - public ReadableChannel StreamBroken() => TestHubMethodsImpl.StreamBroken(); + public ChannelReader StreamBroken() => TestHubMethodsImpl.StreamBroken(); public async Task CallEcho(string message) { @@ -40,9 +40,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests public IObservable Stream(int count) => TestHubMethodsImpl.Stream(count); - public ReadableChannel StreamException() => TestHubMethodsImpl.StreamException(); + public ChannelReader StreamException() => TestHubMethodsImpl.StreamException(); - public ReadableChannel StreamBroken() => TestHubMethodsImpl.StreamBroken(); + public ChannelReader StreamBroken() => TestHubMethodsImpl.StreamBroken(); public async Task CallEcho(string message) { @@ -63,9 +63,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests public IObservable Stream(int count) => TestHubMethodsImpl.Stream(count); - public ReadableChannel StreamException() => TestHubMethodsImpl.StreamException(); + public ChannelReader StreamException() => TestHubMethodsImpl.StreamException(); - public ReadableChannel StreamBroken() => TestHubMethodsImpl.StreamBroken(); + public ChannelReader StreamBroken() => TestHubMethodsImpl.StreamBroken(); public async Task CallEcho(string message) { @@ -97,12 +97,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests .Take(count); } - public static ReadableChannel StreamException() + public static ChannelReader StreamException() { throw new InvalidOperationException("Error occurred while streaming."); } - public static ReadableChannel StreamBroken() => null; + public static ChannelReader StreamBroken() => null; } public interface ITestHub diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs index 58a20ef28b..5a613c7a70 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs @@ -7,7 +7,7 @@ using System.Net.Http; using System.Text; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Client.Tests; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Features; @@ -268,8 +268,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests { // The connection is now in the Disconnected state so the Received event for // this message should not be raised - channel.Out.TryWrite(Array.Empty()); - channel.Out.TryComplete(); + channel.Writer.TryWrite(Array.Empty()); + channel.Writer.TryComplete(); return Task.CompletedTask; }); mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text); @@ -313,7 +313,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests mockTransport.Setup(t => t.StopAsync()) .Returns(() => { - channel.Out.TryComplete(); + channel.Writer.TryComplete(); return Task.CompletedTask; }); mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text); @@ -330,14 +330,14 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests }); await connection.StartAsync(); - channel.Out.TryWrite(Array.Empty()); + channel.Writer.TryWrite(Array.Empty()); // Ensure that the Received callback has been called before attempting the second write await callbackInvokedTcs.Task.OrTimeout(); - channel.Out.TryWrite(Array.Empty()); + channel.Writer.TryWrite(Array.Empty()); // Ensure that SignalR isn't blocked by the receive callback - Assert.False(channel.In.TryRead(out var message)); + Assert.False(channel.Reader.TryRead(out var message)); closedTcs.SetResult(null); @@ -369,7 +369,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests mockTransport.Setup(t => t.StopAsync()) .Returns(() => { - channel.Out.TryComplete(); + channel.Writer.TryComplete(); return Task.CompletedTask; }); mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text); @@ -380,10 +380,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests connection.OnReceived(_ => blockReceiveCallbackTcs.Task); await connection.StartAsync(); - channel.Out.TryWrite(Array.Empty()); + channel.Writer.TryWrite(Array.Empty()); // Ensure that SignalR isn't blocked by the receive callback - Assert.False(channel.In.TryRead(out var message)); + Assert.False(channel.Reader.TryRead(out var message)); await connection.DisposeAsync(); } @@ -413,7 +413,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests mockTransport.Setup(t => t.StopAsync()) .Returns(() => { - channel.Out.TryComplete(); + channel.Writer.TryComplete(); return Task.CompletedTask; }); mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text); @@ -427,10 +427,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests }); await connection.StartAsync(); - channel.Out.TryWrite(Array.Empty()); + channel.Writer.TryWrite(Array.Empty()); // Ensure that SignalR isn't blocked by the receive callback - Assert.False(channel.In.TryRead(out var message)); + Assert.False(channel.Reader.TryRead(out var message)); await connection.DisposeAsync(); } @@ -909,7 +909,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests mockTransport.Setup(t => t.StopAsync()) .Returns(() => { - channel.Out.TryComplete(); + channel.Writer.TryComplete(); return Task.CompletedTask; }); mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Binary); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs index 6aec848eb0..1c6e121f53 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs @@ -6,7 +6,7 @@ using System.Globalization; using System.IO; using System.Text; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index 1114955fc0..13da4d90ec 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -215,7 +215,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public ProtocolType Type => ProtocolType.Binary; - public bool TryParseMessages(ReadOnlyBuffer input, IInvocationBinder binder, out IList messages) + public bool TryParseMessages(ReadOnlyMemory input, IInvocationBinder binder, out IList messages) { messages = new List(); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs index 1892508f13..5f14541613 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -8,7 +8,7 @@ using System.Net.Http; using System.Text; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; @@ -83,7 +83,7 @@ namespace Microsoft.AspNetCore.Client.Tests await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); await longPollingTransport.Running.OrTimeout(); - Assert.True(transportToConnection.In.Completion.IsCompleted); + Assert.True(transportToConnection.Reader.Completion.IsCompleted); } finally { @@ -135,9 +135,9 @@ namespace Microsoft.AspNetCore.Client.Tests var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); - var data = await transportToConnection.In.ReadAllAsync().OrTimeout(); + var data = await transportToConnection.Reader.ReadAllAsync().OrTimeout(); await longPollingTransport.Running.OrTimeout(); - Assert.True(transportToConnection.In.Completion.IsCompleted); + Assert.True(transportToConnection.Reader.Completion.IsCompleted); Assert.Equal(2, data.Count); Assert.Equal(Encoding.UTF8.GetBytes("Hello"), data[0]); Assert.Equal(Encoding.UTF8.GetBytes("World"), data[1]); @@ -172,7 +172,7 @@ namespace Microsoft.AspNetCore.Client.Tests await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); var exception = - await Assert.ThrowsAsync(async () => await transportToConnection.In.Completion.OrTimeout()); + await Assert.ThrowsAsync(async () => await transportToConnection.Reader.Completion.OrTimeout()); Assert.Contains(" 500 ", exception.Message); } finally @@ -207,16 +207,16 @@ namespace Microsoft.AspNetCore.Client.Tests var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); - await connectionToTransport.Out.WriteAsync(new SendMessage()); + await connectionToTransport.Writer.WriteAsync(new SendMessage()); await Assert.ThrowsAsync(async () => await longPollingTransport.Running.OrTimeout()); // The channel needs to be drained for the Completion task to be completed - while (transportToConnection.In.TryRead(out var message)) + while (transportToConnection.Reader.TryRead(out var message)) { } - var exception = await Assert.ThrowsAsync(async () => await transportToConnection.In.Completion); + var exception = await Assert.ThrowsAsync(async () => await transportToConnection.Reader.Completion); Assert.Contains(" 500 ", exception.Message); } finally @@ -248,12 +248,12 @@ namespace Microsoft.AspNetCore.Client.Tests var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); - connectionToTransport.Out.Complete(); + connectionToTransport.Writer.Complete(); await longPollingTransport.Running.OrTimeout(); await longPollingTransport.Running.OrTimeout(); - await connectionToTransport.In.Completion.OrTimeout(); + await connectionToTransport.Reader.Completion.OrTimeout(); } finally { @@ -304,9 +304,9 @@ namespace Microsoft.AspNetCore.Client.Tests // Pull Messages out of the channel var messages = new List(); - while (await transportToConnection.In.WaitToReadAsync()) + while (await transportToConnection.Reader.WaitToReadAsync()) { - while (transportToConnection.In.TryRead(out var message)) + while (transportToConnection.Reader.TryRead(out var message)) { messages.Add(message); } @@ -358,16 +358,16 @@ namespace Microsoft.AspNetCore.Client.Tests var tcs2 = new TaskCompletionSource(); // Pre-queue some messages - await connectionToTransport.Out.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("Hello"), tcs1)).OrTimeout(); - await connectionToTransport.Out.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("World"), tcs2)).OrTimeout(); + await connectionToTransport.Writer.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("Hello"), tcs1)).OrTimeout(); + await connectionToTransport.Writer.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("World"), tcs2)).OrTimeout(); // Start the transport await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); - connectionToTransport.Out.Complete(); + connectionToTransport.Writer.Complete(); await longPollingTransport.Running.OrTimeout(); - await connectionToTransport.In.Completion.OrTimeout(); + await connectionToTransport.Reader.Completion.OrTimeout(); Assert.Single(sentRequests); Assert.Equal(new byte[] { (byte)'H', (byte)'e', (byte)'l', (byte)'l', (byte)'o', (byte)'W', (byte)'o', (byte)'r', (byte)'l', (byte)'d' diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsParserTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsParserTests.cs index ffe4031546..4e9d1bfae4 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsParserTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsParserTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; using System.Collections.Generic; using System.IO.Pipelines; using System.Text; @@ -106,10 +107,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests [InlineData(new[] { "data: Hello, World\r\n", ":comment\r\n", "\r\n" }, "Hello, World")] public async Task ParseMessageAcrossMultipleReadsSuccess(string[] messageParts, string expectedMessage) { - using (var pipeFactory = new PipeFactory()) + var parser = new ServerSentEventsMessageParser(); + using (var pool = new MemoryPool()) { - var parser = new ServerSentEventsMessageParser(); - var pipe = pipeFactory.Create(); + var pipe = new Pipe(new PipeOptions(pool)); byte[] message = null; ReadCursor consumed = default, examined = default; @@ -152,9 +153,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests [InlineData("data: B\r\ndata: SGVs", "bG8sIFdvcmxk\r\n\n\n", "There was an error in the frame format")] public async Task ParseMessageAcrossMultipleReadsFailure(string encodedMessagePart1, string encodedMessagePart2, string expectedMessage) { - using (var pipeFactory = new PipeFactory()) + using (var pool = new MemoryPool()) { - var pipe = pipeFactory.Create(); + var pipe = new Pipe(new PipeOptions(pool)); // Read the first part of the message await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes(encodedMessagePart1)); @@ -173,7 +174,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var ex = Assert.Throws(() => parser.ParseMessage(result.Buffer, out consumed, out examined, out buffer)); Assert.Equal(expectedMessage, ex.Message); - } } @@ -181,9 +181,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests [InlineData("data: foo\r\n\r\n", "data: bar\r\n\r\n")] public async Task ParseMultipleMessagesText(string message1, string message2) { - using (var pipeFactory = new PipeFactory()) + using (var pool = new MemoryPool()) { - var pipe = pipeFactory.Create(); + var pipe = new Pipe(new PipeOptions(pool)); // Read the first part of the message await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes(message1 + message2)); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs index ed5a2d0054..622996bef3 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -8,8 +8,9 @@ using System.Net.Http.Headers; using System.Text; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Client.Tests; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; @@ -42,6 +43,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests mockStream .Setup(s => s.CopyToAsync(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(copyToAsyncTcs.Task); + mockStream.Setup(s => s.CanRead).Returns(true); return new HttpResponseMessage { Content = new StreamContent(mockStream.Object) }; }); @@ -83,12 +85,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests .Setup(s => s.CopyToAsync(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(async (stream, bufferSize, t) => { + await Task.Yield(); var buffer = Encoding.ASCII.GetBytes("data: 3:abc\r\n\r\n"); while (!eventStreamCts.IsCancellationRequested) { await stream.WriteAsync(buffer, 0, buffer.Length); } }); + mockStream.Setup(s => s.CanRead).Returns(true); return new HttpResponseMessage { Content = new StreamContent(mockStream.Object) }; }); @@ -109,7 +113,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests transportActiveTask = sseTransport.Running; Assert.False(transportActiveTask.IsCompleted); - var message = await transportToConnection.In.ReadAsync().AsTask().OrTimeout(); + var message = await transportToConnection.Reader.ReadAsync().AsTask().OrTimeout(); Assert.Equal("3:abc", Encoding.ASCII.GetString(message)); } finally @@ -140,6 +144,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var buffer = Encoding.ASCII.GetBytes("data: 3:a"); await stream.WriteAsync(buffer, 0, buffer.Length); }); + mockStream.Setup(s => s.CanRead).Returns(true); return new HttpResponseMessage { Content = new StreamContent(mockStream.Object) }; }); @@ -182,6 +187,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests mockStream .Setup(s => s.CopyToAsync(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(copyToAsyncTcs.Task); + mockStream.Setup(s => s.CanRead).Returns(true); return new HttpResponseMessage { Content = new StreamContent(mockStream.Object) }; } @@ -201,7 +207,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await eventStreamTcs.Task; var sendTcs = new TaskCompletionSource(); - Assert.True(connectionToTransport.Out.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs))); + Assert.True(connectionToTransport.Writer.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs))); var exception = await Assert.ThrowsAsync(() => sendTcs.Task.OrTimeout()); Assert.Contains("500", exception.Message); @@ -231,6 +237,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests mockStream .Setup(s => s.CopyToAsync(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(copyToAsyncTcs.Task); + mockStream.Setup(s => s.CanRead).Returns(true); return new HttpResponseMessage { Content = new StreamContent(mockStream.Object) }; }); @@ -246,7 +253,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); await eventStreamTcs.Task.OrTimeout(); - connectionToTransport.Out.TryComplete(null); + connectionToTransport.Writer.TryComplete(null); await sseTransport.Running.OrTimeout(); } @@ -274,7 +281,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await sseTransport.StartAsync( new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); - var message = await transportToConnection.In.ReadAsync().AsTask().OrTimeout(); + var message = await transportToConnection.Reader.ReadAsync().AsTask().OrTimeout(); Assert.Equal("3:abc", Encoding.ASCII.GetString(message)); await sseTransport.Running.OrTimeout(); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index d234931fd8..6e8c8be30d 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -7,8 +7,9 @@ using System.IO; using System.Text; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; @@ -34,8 +35,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public Task Closed => _closeTcs.Task; public Task Started => _started.Task; public Task Disposed => _disposed.Task; - public ReadableChannel SentMessages => _sentMessages.In; - public WritableChannel ReceivedMessages => _receivedMessages.Out; + public ChannelReader SentMessages => _sentMessages.Reader; + public ChannelWriter ReceivedMessages => _receivedMessages.Writer; private readonly List _callbacks = new List(); @@ -61,9 +62,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests throw new InvalidOperationException("Connection must be started before SendAsync can be called"); } - while (await _sentMessages.Out.WaitToWriteAsync(cancellationToken)) + while (await _sentMessages.Writer.WaitToWriteAsync(cancellationToken)) { - if (_sentMessages.Out.TryWrite(data)) + if (_sentMessages.Writer.TryWrite(data)) { return; } @@ -100,7 +101,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var json = JsonConvert.SerializeObject(jsonObject, Formatting.None); var bytes = FormatMessageToArray(Encoding.UTF8.GetBytes(json)); - return _receivedMessages.Out.WriteAsync(bytes); + return _receivedMessages.Writer.WriteAsync(bytes); } private byte[] FormatMessageToArray(byte[] message) @@ -116,9 +117,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { while (!token.IsCancellationRequested) { - while (await _receivedMessages.In.WaitToReadAsync(token)) + while (await _receivedMessages.Reader.WaitToReadAsync(token)) { - while (_receivedMessages.In.TryRead(out var message)) + while (_receivedMessages.Reader.TryRead(out var message)) { ReceiveCallback[] callbackCopies; lock (_callbacks) diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Encoders/LengthPrefixedTextMessageParserTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Encoders/LengthPrefixedTextMessageParserTests.cs index 78e296ce29..923e80dd66 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Encoders/LengthPrefixedTextMessageParserTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Encoders/LengthPrefixedTextMessageParserTests.cs @@ -18,7 +18,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Encoders [InlineData("12:Hello, World;", "Hello, World")] public void ReadTextMessage(string encoded, string payload) { - ReadOnlyBuffer buffer = Encoding.UTF8.GetBytes(encoded); + ReadOnlyMemory buffer = Encoding.UTF8.GetBytes(encoded); Assert.True(LengthPrefixedTextMessageParser.TryParseMessage(ref buffer, out var message)); Assert.Equal(0, buffer.Length); @@ -29,7 +29,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Encoders public void ReadMultipleMessages() { const string encoded = "0:;14:Hello,\r\nWorld!;"; - ReadOnlyBuffer buffer = Encoding.UTF8.GetBytes(encoded); + ReadOnlyMemory buffer = Encoding.UTF8.GetBytes(encoded); var messages = new List(); while (LengthPrefixedTextMessageParser.TryParseMessage(ref buffer, out var message)) @@ -54,7 +54,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Encoders [InlineData("5:ABCDE")] public void ReadIncompleteMessages(string encoded) { - ReadOnlyBuffer buffer = Encoding.UTF8.GetBytes(encoded); + ReadOnlyMemory buffer = Encoding.UTF8.GetBytes(encoded); Assert.False(LengthPrefixedTextMessageParser.TryParseMessage(ref buffer, out _)); } @@ -66,7 +66,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Encoders [InlineData("5:ABCDEF", "Missing delimiter ';' after payload")] public void ReadInvalidMessages(string encoded, string expectedMessage) { - ReadOnlyBuffer buffer = Encoding.UTF8.GetBytes(encoded); + ReadOnlyMemory buffer = Encoding.UTF8.GetBytes(encoded); var ex = Assert.Throws(() => { LengthPrefixedTextMessageParser.TryParseMessage(ref buffer, out _); @@ -79,7 +79,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Encoders { // Invalid because first character is a UTF-8 "continuation" character // We need to include the ':' so that - ReadOnlyBuffer buffer = new byte[] { 0x48, 0x65, 0x80, 0x6C, 0x6F, (byte)':' }; + ReadOnlyMemory buffer = new byte[] { 0x48, 0x65, 0x80, 0x6C, 0x6F, (byte)':' }; var ex = Assert.Throws(() => { LengthPrefixedTextMessageParser.TryParseMessage(ref buffer, out _); diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs index 7a35a6e16e..b727e8124d 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs @@ -109,7 +109,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests.Internal.Formatters using (var ms = new MemoryStream()) { BinaryMessageFormatter.WriteMessage(payload, ms); - var buffer = new ReadOnlyBuffer(ms.ToArray()); + var buffer = new ReadOnlyMemory(ms.ToArray()); Assert.True(BinaryMessageParser.TryParseMessage(ref buffer, out var roundtripped)); Assert.Equal(payload, roundtripped.ToArray()); } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageParserTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageParserTests.cs index 58dfae9c99..7a42eca843 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageParserTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageParserTests.cs @@ -17,7 +17,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters [InlineData(new byte[] { 0x0B, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, "A\nR\rC\r\n;DEF")] public void ReadMessage(byte[] encoded, string payload) { - ReadOnlyBuffer span = encoded; + ReadOnlyMemory span = encoded; Assert.True(BinaryMessageParser.TryParseMessage(ref span, out var message)); Assert.Equal(0, span.Length); @@ -52,7 +52,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters })] public void ReadBinaryMessage(byte[] encoded, byte[] payload) { - ReadOnlyBuffer span = encoded; + ReadOnlyMemory span = encoded; Assert.True(BinaryMessageParser.TryParseMessage(ref span, out var message)); Assert.Equal(0, span.Length); Assert.Equal(payload, message.ToArray()); @@ -64,7 +64,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters [InlineData(new byte[] { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF })] public void BinaryMessageParserThrowsForMessagesOver2GB(byte[] payload) { - var buffer = new ReadOnlyBuffer(payload); + var buffer = new ReadOnlyMemory(payload); var ex = Assert.Throws(() => BinaryMessageParser.TryParseMessage(ref buffer, out var message)); Assert.Equal("Messages over 2GB in size are not supported.", ex.Message); } @@ -76,7 +76,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters [InlineData(new byte[] { 0x80 })] // size is cut public void BinaryMessageParserReturnsFalseForPartialPayloads(byte[] payload) { - var buffer = new ReadOnlyBuffer(payload); + var buffer = new ReadOnlyMemory(payload); Assert.False(BinaryMessageParser.TryParseMessage(ref buffer, out var message)); } @@ -90,7 +90,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters /* length: */ 0x0E, /* body: */ 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x2C, 0x0D, 0x0A, 0x57, 0x6F, 0x72, 0x6C, 0x64, 0x21, }; - ReadOnlyBuffer buffer = encoded; + ReadOnlyMemory buffer = encoded; var messages = new List(); while (BinaryMessageParser.TryParseMessage(ref buffer, out var message)) @@ -110,7 +110,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters [InlineData(new byte[] { 0x09, 0x00, 0x00 })] // Not enough data for payload public void ReadIncompleteMessages(byte[] encoded) { - ReadOnlyBuffer buffer = encoded; + ReadOnlyMemory buffer = encoded; Assert.False(BinaryMessageParser.TryParseMessage(ref buffer, out var message)); Assert.Equal(encoded.Length, buffer.Length); } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/TextMessageParserTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/TextMessageParserTests.cs index e6d953cf87..9dbc7b2866 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/TextMessageParserTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/TextMessageParserTests.cs @@ -13,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters [Fact] public void ReadMessage() { - var message = new ReadOnlyBuffer(Encoding.UTF8.GetBytes("ABC\u001e")); + var message = new ReadOnlyMemory(Encoding.UTF8.GetBytes("ABC\u001e")); Assert.True(TextMessageParser.TryParseMessage(ref message, out var payload)); Assert.Equal("ABC", Encoding.UTF8.GetString(payload.ToArray())); @@ -23,14 +23,14 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters [Fact] public void TryReadingIncompleteMessage() { - var message = new ReadOnlyBuffer(Encoding.UTF8.GetBytes("ABC")); + var message = new ReadOnlyMemory(Encoding.UTF8.GetBytes("ABC")); Assert.False(TextMessageParser.TryParseMessage(ref message, out var payload)); } [Fact] public void TryReadingMultipleMessages() { - var message = new ReadOnlyBuffer(Encoding.UTF8.GetBytes("ABC\u001eXYZ\u001e")); + var message = new ReadOnlyMemory(Encoding.UTF8.GetBytes("ABC\u001eXYZ\u001e")); Assert.True(TextMessageParser.TryParseMessage(ref message, out var payload)); Assert.Equal("ABC", Encoding.UTF8.GetString(payload.ToArray())); Assert.True(TextMessageParser.TryParseMessage(ref message, out payload)); @@ -40,7 +40,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters [Fact] public void IncompleteTrailingMessage() { - var message = new ReadOnlyBuffer(Encoding.UTF8.GetBytes("ABC\u001eXYZ\u001e123")); + var message = new ReadOnlyMemory(Encoding.UTF8.GetBytes("ABC\u001eXYZ\u001e123")); Assert.True(TextMessageParser.TryParseMessage(ref message, out var payload)); Assert.Equal("ABC", Encoding.UTF8.GetString(payload.ToArray())); Assert.True(TextMessageParser.TryParseMessage(ref message, out payload)); diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs index e761309382..7dcc10a447 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs @@ -4,7 +4,7 @@ using System; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.AspNetCore.SignalR.Tests.Common; @@ -70,7 +70,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests AssertMessage(output1); - Assert.False(output2.In.TryRead(out var item)); + Assert.False(output2.Reader.TryRead(out var item)); } } @@ -100,7 +100,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests AssertMessage(output1); - Assert.False(output2.In.TryRead(out var item)); + Assert.False(output2.Reader.TryRead(out var item)); } } @@ -201,7 +201,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests AssertMessage(output1); - Assert.False(output2.In.TryRead(out var item)); + Assert.False(output2.Reader.TryRead(out var item)); } } @@ -286,7 +286,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager.InvokeGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - Assert.False(output.In.TryRead(out var item)); + Assert.False(output.Reader.TryRead(out var item)); } } @@ -387,7 +387,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager.InvokeGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); AssertMessage(output); - Assert.False(output.In.TryRead(out var item)); + Assert.False(output.Reader.TryRead(out var item)); } } @@ -417,7 +417,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager2.InvokeGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); AssertMessage(output); - Assert.False(output.In.TryRead(out var item)); + Assert.False(output.Reader.TryRead(out var item)); } } @@ -451,7 +451,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager2.InvokeGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - Assert.False(output.In.TryRead(out var item)); + Assert.False(output.Reader.TryRead(out var item)); } } @@ -480,7 +480,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager1.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); AssertMessage(output); - Assert.False(output.In.TryRead(out var item)); + Assert.False(output.Reader.TryRead(out var item)); } } @@ -499,10 +499,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests using (var client = new TestClient()) { // Force an exception when writing to connection - var output = new Mock>(); - output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny())).Throws(new Exception()); + var writer = new Mock>(); + writer.Setup(o => o.WaitToWriteAsync(It.IsAny())).Throws(new Exception()); - var connection = new HubConnectionContext(output.Object, client.Connection); + var connection = new HubConnectionContext(new MockChannel(writer.Object), client.Connection); await manager2.OnConnectedAsync(connection).OrTimeout(); @@ -523,10 +523,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests using (var client = new TestClient()) { // Force an exception when writing to connection - var output = new Mock>(); - output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny())).Throws(new Exception("Message")); + var writer = new Mock>(); + writer.Setup(o => o.WaitToWriteAsync(It.IsAny())).Throws(new Exception("Message")); - var connection = new HubConnectionContext(output.Object, client.Connection); + var connection = new HubConnectionContext(new MockChannel(writer.Object), client.Connection); await manager.OnConnectedAsync(connection).OrTimeout(); @@ -549,10 +549,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests var output2 = Channel.CreateUnbounded(); // Force an exception when writing to connection - var output = new Mock>(); - output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny())).Throws(new Exception()); + var writer = new Mock>(); + writer.Setup(o => o.WaitToWriteAsync(It.IsAny())).Throws(new Exception()); - var connection1 = new HubConnectionContext(output.Object, client1.Connection); + var connection1 = new HubConnectionContext(new MockChannel(writer.Object), client1.Connection); var connection2 = new HubConnectionContext(output2, client2.Connection); await manager.OnConnectedAsync(connection1).OrTimeout(); @@ -573,7 +573,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests private void AssertMessage(Channel channel) { - Assert.True(channel.In.TryRead(out var item)); + Assert.True(channel.Reader.TryRead(out var item)); var message = Assert.IsType(item); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); @@ -583,5 +583,13 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests private class MyHub : Hub { } + + private class MockChannel : Channel + { + public MockChannel(ChannelWriter writer = null) + { + Writer = writer; + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs index 81dd798b2a..7065eadd9d 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs @@ -1,7 +1,7 @@ -using System; +using System; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.Common; using Moq; @@ -29,13 +29,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout(); - Assert.True(output1.In.TryRead(out var item)); + Assert.True(output1.Reader.TryRead(out var item)); var message = Assert.IsType(item); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); Assert.Equal("World", (string)message.Arguments[0]); - Assert.True(output2.In.TryRead(out item)); + Assert.True(output2.Reader.TryRead(out item)); message = Assert.IsType(item); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); @@ -63,13 +63,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout(); - Assert.True(output1.In.TryRead(out var item)); + Assert.True(output1.Reader.TryRead(out var item)); var message = Assert.IsType(item); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); Assert.Equal("World", (string)message.Arguments[0]); - Assert.False(output2.In.TryRead(out item)); + Assert.False(output2.Reader.TryRead(out item)); } } @@ -93,13 +93,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" }).OrTimeout(); - Assert.True(output1.In.TryRead(out var item)); + Assert.True(output1.Reader.TryRead(out var item)); var message = Assert.IsType(item); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); Assert.Equal("World", (string)message.Arguments[0]); - Assert.False(output2.In.TryRead(out item)); + Assert.False(output2.Reader.TryRead(out item)); } } @@ -116,7 +116,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); - Assert.True(output.In.TryRead(out var item)); + Assert.True(output.Reader.TryRead(out var item)); var message = Assert.IsType(item); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); @@ -130,11 +130,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var client = new TestClient()) { // Force an exception when writing to connection - var output = new Mock>(); - output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny())).Throws(new Exception("Message")); + var writer = new Mock>(); + writer.Setup(o => o.WaitToWriteAsync(It.IsAny())).Throws(new Exception("Message")); var manager = new DefaultHubLifetimeManager(); - var connection = new HubConnectionContext(output.Object, client.Connection); + var connection = new HubConnectionContext(new MockChannel(writer.Object), client.Connection); await manager.OnConnectedAsync(connection).OrTimeout(); @@ -168,5 +168,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests { } + + private class MockChannel: Channel + { + + public MockChannel(ChannelWriter writer = null) + { + Writer = writer; + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs index 08c0a5d5b3..034230b956 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.Sockets; namespace Microsoft.AspNetCore.SignalR.Tests @@ -10,7 +11,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { public async override Task OnConnectedAsync(ConnectionContext connection) { - await connection.Transport.Out.WriteAsync(await connection.Transport.In.ReadAsync()); + await connection.Transport.Writer.WriteAsync(await connection.Transport.Reader.ReadAsync()); } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index bcd31c4fbf..a9c806748f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -8,7 +8,7 @@ using System.Runtime.Serialization; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.SignalR.Internal; @@ -259,7 +259,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var client = new TestClient()) { // TestClient automatically writes negotiate, for this test we want to assume negotiate never gets sent - client.Connection.Transport.In.TryRead(out var item); + client.Connection.Transport.Reader.TryRead(out var item); var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -285,7 +285,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var client = new TestClient()) { // TestClient automatically writes negotiate, for this test we want to assume negotiate never gets sent - client.Connection.Transport.In.TryRead(out var item); + client.Connection.Transport.Reader.TryRead(out var item); await endPoint.OnConnectedAsync(client.Connection).OrTimeout(); } @@ -521,7 +521,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests await client.SendInvocationAsync(methodName, nonBlocking: true).OrTimeout(); // Nothing should have been written - Assert.False(client.Application.In.TryRead(out var buffer)); + Assert.False(client.Application.Reader.TryRead(out var buffer)); // kill the connection client.Dispose(); @@ -1595,7 +1595,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests return new CountingObservable(count); } - public ReadableChannel CounterChannel(int count) + public ChannelReader CounterChannel(int count) { var channel = Channel.CreateUnbounded(); @@ -1603,17 +1603,17 @@ namespace Microsoft.AspNetCore.SignalR.Tests { for (int i = 0; i < count; i++) { - await channel.Out.WriteAsync(i.ToString()); + await channel.Writer.WriteAsync(i.ToString()); } - channel.Out.Complete(); + channel.Writer.Complete(); }); - return channel.In; + return channel.Reader; } - public ReadableChannel BlockingStream() + public ChannelReader BlockingStream() { - return Channel.CreateUnbounded().In; + return Channel.CreateUnbounded().Reader; } private class CountingObservable : IObservable diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs index b0596994cd..dba1123dd3 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs @@ -1,9 +1,9 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Collections.Generic; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; @@ -20,7 +20,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests [MemberData(nameof(HubProtocols))] public void DefaultHubProtocolResolverTestsCanCreateSupportedProtocols(IHubProtocol protocol) { - var mockConnection = new Mock(Channel.CreateUnbounded().Out, new Mock().Object); + var mockConnection = new Mock(Channel.CreateUnbounded().Writer, new Mock().Object); Assert.IsType( protocol.GetType(), new DefaultHubProtocolResolver(Options.Create(new HubOptions())).GetProtocol(protocol.Name, mockConnection.Object)); @@ -31,7 +31,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests [InlineData("dummy")] public void DefaultHubProtocolResolverThrowsForNotSupportedProtocol(string protocolName) { - var mockConnection = new Mock(Channel.CreateUnbounded().Out, new Mock().Object); + var mockConnection = new Mock(Channel.CreateUnbounded().Writer, new Mock().Object); var exception = Assert.Throws( () => new DefaultHubProtocolResolver(Options.Create(new HubOptions())).GetProtocol(protocolName, mockConnection.Object)); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs index b46f1757de..edb9099fdb 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs @@ -1,9 +1,9 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; @@ -61,7 +61,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests var webSocketsTransport = new WebSocketsTransport(loggerFactory); await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, TransferMode.Binary, connectionId: string.Empty); - connectionToTransport.Out.TryComplete(); + connectionToTransport.Writer.TryComplete(); await webSocketsTransport.Running.OrTimeout(TimeSpan.FromSeconds(10)); } } @@ -82,7 +82,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, transferMode, connectionId: string.Empty); var sendTcs = new TaskCompletionSource(); - connectionToTransport.Out.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs)); + connectionToTransport.Writer.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs)); try { await sendTcs.Task; @@ -99,7 +99,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests // The echo endpoint closes the connection immediately after sending response which should stop the transport await webSocketsTransport.Running.OrTimeout(); - Assert.True(transportToConnection.In.TryRead(out var buffer)); + Assert.True(transportToConnection.Reader.TryRead(out var buffer)); Assert.Equal(new byte[] { 0x42 }, buffer); } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs index 9f0c56a1a1..f2cb45f02c 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs @@ -82,12 +82,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests connection.ApplicationTask = Task.Run(async () => { - Assert.False(await connection.Transport.In.WaitToReadAsync()); + Assert.False(await connection.Transport.Reader.WaitToReadAsync()); }); connection.TransportTask = Task.Run(async () => { - Assert.False(await connection.Application.In.WaitToReadAsync()); + Assert.False(await connection.Application.Reader.WaitToReadAsync()); }); connectionManager.CloseConnections(); @@ -197,7 +197,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests appLifetime.StopApplication(); // Connection should be disposed so this should complete immediately - Assert.False(await connection.Application.Out.WaitToWriteAsync().OrTimeout()); + Assert.False(await connection.Application.Writer.WaitToWriteAsync().OrTimeout()); } private static ConnectionManager CreateConnectionManager(IApplicationLifetime lifetime = null) diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index b71067b857..c0605de722 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -511,7 +511,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var buffer = Encoding.UTF8.GetBytes("Hello World"); // Write to the transport so the poll yields - await connection.Transport.Out.WriteAsync(buffer); + await connection.Transport.Writer.WriteAsync(buffer); await task; @@ -543,7 +543,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var buffer = Encoding.UTF8.GetBytes("Hello World"); // Write to the application - await connection.Application.Out.WriteAsync(buffer); + await connection.Application.Writer.WriteAsync(buffer); await task; @@ -573,7 +573,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var buffer = Encoding.UTF8.GetBytes("Hello World"); // Write to the application - await connection.Application.Out.WriteAsync(buffer); + await connection.Application.Writer.WriteAsync(buffer); await task; @@ -606,7 +606,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests await task1.OrTimeout(); // Send a message from the app to complete Task 2 - await connection.Transport.Out.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")); + await connection.Transport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")); await task2.OrTimeout(); @@ -775,7 +775,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); var endPointTask = dispatcher.ExecuteAsync(context, options, app); - await connection.Transport.Out.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout(); + await connection.Transport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout(); await endPointTask.OrTimeout(); @@ -853,7 +853,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests })); var endPointTask = dispatcher.ExecuteAsync(context, options, app); - await connection.Transport.Out.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout(); + await connection.Transport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout(); await endPointTask.OrTimeout(); @@ -907,7 +907,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); var endPointTask = dispatcher.ExecuteAsync(context, options, app); - await connection.Transport.Out.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout(); + await connection.Transport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout(); await endPointTask.OrTimeout(); @@ -1110,7 +1110,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests { public override Task OnConnectedAsync(ConnectionContext connection) { - connection.Transport.In.WaitToReadAsync().Wait(); + connection.Transport.Reader.WaitToReadAsync().Wait(); return Task.CompletedTask; } } @@ -1135,7 +1135,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests { public override async Task OnConnectedAsync(ConnectionContext connection) { - while (await connection.Transport.In.WaitToReadAsync()) + while (await connection.Transport.Reader.WaitToReadAsync()) { } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs index c851dfe713..112314f6ac 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs @@ -1,11 +1,11 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.IO; using System.Text; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Internal.Transports; @@ -23,7 +23,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var poll = new LongPollingTransport(CancellationToken.None, channel, connectionId: string.Empty, loggerFactory: new LoggerFactory()); - Assert.True(channel.Out.TryComplete()); + Assert.True(channel.Writer.TryComplete()); await poll.ProcessRequestAsync(context, context.RequestAborted); @@ -56,9 +56,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests var ms = new MemoryStream(); context.Response.Body = ms; - await channel.Out.WriteAsync(Encoding.UTF8.GetBytes("Hello World")); + await channel.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello World")); - Assert.True(channel.Out.TryComplete()); + Assert.True(channel.Writer.TryComplete()); await poll.ProcessRequestAsync(context, context.RequestAborted); @@ -76,11 +76,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests var ms = new MemoryStream(); context.Response.Body = ms; - await channel.Out.WriteAsync(Encoding.UTF8.GetBytes("Hello")); - await channel.Out.WriteAsync(Encoding.UTF8.GetBytes(" ")); - await channel.Out.WriteAsync(Encoding.UTF8.GetBytes("World")); + await channel.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello")); + await channel.Writer.WriteAsync(Encoding.UTF8.GetBytes(" ")); + await channel.Writer.WriteAsync(Encoding.UTF8.GetBytes("World")); - Assert.True(channel.Out.TryComplete()); + Assert.True(channel.Writer.TryComplete()); await poll.ProcessRequestAsync(context, context.RequestAborted); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs index 79a0bd9d34..da1ae0c0fa 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -13,12 +13,21 @@ using Microsoft.AspNetCore.Hosting.Server.Features; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Xunit; +using Xunit.Abstractions; namespace Microsoft.AspNetCore.Sockets.Tests { public class MapEndPointTests { + private ITestOutputHelper _output; + + public MapEndPointTests(ITestOutputHelper output) + { + _output = output; + } + [Fact] public void MapEndPointFindsAuthAttributeOnEndPoint() { @@ -40,6 +49,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests }); }); }) + .ConfigureLogging(factory => + { + factory.AddXunit(_output, LogLevel.Trace); + }) .Build(); Assert.Equal(1, authCount); @@ -66,6 +79,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests }); }); }) + .ConfigureLogging(factory => + { + factory.AddXunit(_output, LogLevel.Trace); + }) .Build(); Assert.Equal(1, authCount); @@ -92,6 +109,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests }); }); }) + .ConfigureLogging(factory => + { + factory.AddXunit(_output, LogLevel.Trace); + }) .Build(); Assert.Equal(2, authCount); @@ -102,24 +123,28 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task MapEndPointWithWebSocketSubProtocolSetsProtocol() { var host = new WebHostBuilder() - .UseUrls("http://127.0.0.1:0") - .UseKestrel() - .ConfigureServices(services => + .UseUrls("http://127.0.0.1:0") + .UseKestrel() + .ConfigureServices(services => + { + services.AddSockets(); + services.AddEndPoint(); + }) + .Configure(app => + { + app.UseSockets(routes => { - services.AddSockets(); - services.AddEndPoint(); - }) - .Configure(app => - { - app.UseSockets(routes => + routes.MapEndPoint("socket", httpSocketOptions => { - routes.MapEndPoint("socket", httpSocketOptions => - { - httpSocketOptions.WebSockets.SubProtocol = "protocol1"; - }); + httpSocketOptions.WebSockets.SubProtocol = "protocol1"; }); - }) - .Build(); + }); + }) + .ConfigureLogging(factory => + { + factory.AddXunit(_output, LogLevel.Trace); + }) + .Build(); await host.StartAsync(); @@ -140,7 +165,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests { public override async Task OnConnectedAsync(ConnectionContext connection) { - while (!await connection.Transport.In.WaitToReadAsync()) + while (!await connection.Transport.Reader.WaitToReadAsync()) { } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj b/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj index c73a09bcc4..432dfec514 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj +++ b/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj @@ -3,7 +3,7 @@ netcoreapp2.0;net461 netcoreapp2.0 - + win7-x64 @@ -21,6 +21,7 @@ + diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs index fbdc4137b8..b7283440b4 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs @@ -1,10 +1,10 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.IO; using System.Text; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.SignalR.Tests.Common; @@ -23,7 +23,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var sse = new ServerSentEventsTransport(channel, connectionId: string.Empty, loggerFactory: new LoggerFactory()); - Assert.True(channel.Out.TryComplete()); + Assert.True(channel.Writer.TryComplete()); await sse.ProcessRequestAsync(context, context.RequestAborted); @@ -40,7 +40,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.Features.Set(feature); var sse = new ServerSentEventsTransport(channel, connectionId: string.Empty, loggerFactory: new LoggerFactory()); - Assert.True(channel.Out.TryComplete()); + Assert.True(channel.Writer.TryComplete()); await sse.ProcessRequestAsync(context, context.RequestAborted); @@ -50,7 +50,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task SSEWritesMessages() { - var channel = Channel.CreateUnbounded(new ChannelOptimizations + var channel = Channel.CreateUnbounded(new UnboundedChannelOptions { AllowSynchronousContinuations = true }); @@ -62,11 +62,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests var task = sse.ProcessRequestAsync(context, context.RequestAborted); - await channel.Out.WriteAsync(Encoding.ASCII.GetBytes("Hello")); + await channel.Writer.WriteAsync(Encoding.ASCII.GetBytes("Hello")); Assert.Equal(":\r\ndata: Hello\r\n\r\n", Encoding.ASCII.GetString(ms.ToArray())); - channel.Out.TryComplete(); + channel.Writer.TryComplete(); await task.OrTimeout(); } @@ -83,9 +83,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests var ms = new MemoryStream(); context.Response.Body = ms; - await channel.Out.WriteAsync(Encoding.UTF8.GetBytes(message)); + await channel.Writer.WriteAsync(Encoding.UTF8.GetBytes(message)); - Assert.True(channel.Out.TryComplete()); + Assert.True(channel.Writer.TryComplete()); await sse.ProcessRequestAsync(context, context.RequestAborted); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/TestWebSocketConnectionFeature.cs b/test/Microsoft.AspNetCore.Sockets.Tests/TestWebSocketConnectionFeature.cs index 29b5cac70f..ea085b939f 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/TestWebSocketConnectionFeature.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/TestWebSocketConnectionFeature.cs @@ -1,9 +1,9 @@ -using System; +using System; using System.Collections.Generic; using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; @@ -22,8 +22,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests var clientToServer = Channel.CreateUnbounded(); var serverToClient = Channel.CreateUnbounded(); - var clientSocket = new WebSocketChannel(serverToClient.In, clientToServer.Out); - var serverSocket = new WebSocketChannel(clientToServer.In, serverToClient.Out); + var clientSocket = new WebSocketChannel(serverToClient.Reader, clientToServer.Writer); + var serverSocket = new WebSocketChannel(clientToServer.Reader, serverToClient.Writer); Client = clientSocket; return Task.FromResult(serverSocket); @@ -35,14 +35,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests public class WebSocketChannel : WebSocket { - private readonly ReadableChannel _input; - private readonly WritableChannel _output; + private readonly ChannelReader _input; + private readonly ChannelWriter _output; private WebSocketCloseStatus? _closeStatus; private string _closeStatusDescription; private WebSocketState _state; - public WebSocketChannel(ReadableChannel input, WritableChannel output) + public WebSocketChannel(ChannelReader input, ChannelWriter output) { _input = input; _output = output; @@ -209,4 +209,4 @@ namespace Microsoft.AspNetCore.Sockets.Tests public string CloseStatusDescription { get; set; } } } -} \ No newline at end of file +} diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs index ba696ff46e..d93d653e9b 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -6,58 +6,68 @@ using System.Net.WebSockets; using System.Text; using System.Threading; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; +using System.Threading.Channels; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Sockets.Internal.Transports; -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; using Xunit; +using Xunit.Abstractions; namespace Microsoft.AspNetCore.Sockets.Tests { - public class WebSocketsTests + public class WebSocketsTests : LoggedTest { + public WebSocketsTests(ITestOutputHelper output) + : base(output) + { + } + [Theory] [InlineData(WebSocketMessageType.Text)] [InlineData(WebSocketMessageType.Binary)] public async Task ReceivedFramesAreWrittenToChannel(WebSocketMessageType webSocketMessageType) { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); - - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) - using (var feature = new TestWebSocketConnectionFeature()) + using (StartLog(out var loggerFactory)) { - var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory: new LoggerFactory()); + var transportToApplication = Channel.CreateUnbounded(); + var applicationToTransport = Channel.CreateUnbounded(); - // Give the server socket to the transport and run it - var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); + using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) + using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) + using (var feature = new TestWebSocketConnectionFeature()) + { + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory); - // Run the client socket - var client = feature.Client.ExecuteAndCaptureFramesAsync(); + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); - // Send a frame, then close - await feature.Client.SendAsync( - buffer: new ArraySegment(Encoding.UTF8.GetBytes("Hello")), - messageType: webSocketMessageType, - endOfMessage: true, - cancellationToken: CancellationToken.None); - await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); - var buffer = await applicationSide.In.ReadAsync(); - Assert.Equal("Hello", Encoding.UTF8.GetString(buffer)); + // Send a frame, then close + await feature.Client.SendAsync( + buffer: new ArraySegment(Encoding.UTF8.GetBytes("Hello")), + messageType: webSocketMessageType, + endOfMessage: true, + cancellationToken: CancellationToken.None); + await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); - Assert.True(applicationSide.Out.TryComplete()); + var buffer = await applicationSide.Reader.ReadAsync(); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer)); - // The transport should finish now - await transport; + Assert.True(applicationSide.Writer.TryComplete()); - // The connection should close after this, which means the client will get a close frame. - var clientSummary = await client; + // The transport should finish now + await transport; - Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.CloseStatus); + // The connection should close after this, which means the client will get a close frame. + var clientSummary = await client; + + Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.CloseStatus); + } } } @@ -66,256 +76,276 @@ namespace Microsoft.AspNetCore.Sockets.Tests [InlineData(TransferMode.Binary, WebSocketMessageType.Binary)] public async Task WebSocketTransportSetsMessageTypeBasedOnTransferModeFeature(TransferMode transferMode, WebSocketMessageType expectedMessageType) { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); - - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) - using (var feature = new TestWebSocketConnectionFeature()) + using (StartLog(out var loggerFactory)) { - var connectionContext = new DefaultConnectionContext(string.Empty, null, null) { TransferMode = transferMode }; - var ws = new WebSocketsTransport(new WebSocketOptions(), - transportSide, connectionContext, loggerFactory: new LoggerFactory()); + var transportToApplication = Channel.CreateUnbounded(); + var applicationToTransport = Channel.CreateUnbounded(); - // Give the server socket to the transport and run it - var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); + using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) + using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) + using (var feature = new TestWebSocketConnectionFeature()) + { + var connectionContext = new DefaultConnectionContext(string.Empty, null, null) { TransferMode = transferMode }; + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory); - // Run the client socket - var client = feature.Client.ExecuteAndCaptureFramesAsync(); + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); - // Write to the output channel, and then complete it - await applicationSide.Out.WriteAsync(Encoding.UTF8.GetBytes("Hello")); - Assert.True(applicationSide.Out.TryComplete()); + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); - // The client should finish now, as should the server - var clientSummary = await client; - await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); - await transport; + // Write to the output channel, and then complete it + await applicationSide.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello")); + Assert.True(applicationSide.Writer.TryComplete()); - Assert.Equal(1, clientSummary.Received.Count); - Assert.True(clientSummary.Received[0].EndOfMessage); - Assert.Equal(expectedMessageType, clientSummary.Received[0].MessageType); - Assert.Equal("Hello", Encoding.UTF8.GetString(clientSummary.Received[0].Buffer)); + // The client should finish now, as should the server + var clientSummary = await client; + await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + await transport; + + Assert.Equal(1, clientSummary.Received.Count); + Assert.True(clientSummary.Received[0].EndOfMessage); + Assert.Equal(expectedMessageType, clientSummary.Received[0].MessageType); + Assert.Equal("Hello", Encoding.UTF8.GetString(clientSummary.Received[0].Buffer)); + } } } [Fact] public async Task TransportFailsWhenClientDisconnectsAbnormally() { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); - - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) - using (var feature = new TestWebSocketConnectionFeature()) + using (StartLog(out var loggerFactory)) { - async Task CompleteApplicationAfterTransportCompletes() + var transportToApplication = Channel.CreateUnbounded(); + var applicationToTransport = Channel.CreateUnbounded(); + + using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) + using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) + using (var feature = new TestWebSocketConnectionFeature()) { - // Wait until the transport completes so that we can end the application - await applicationSide.In.WaitToReadAsync(); + async Task CompleteApplicationAfterTransportCompletes() + { + // Wait until the transport completes so that we can end the application + await applicationSide.Reader.WaitToReadAsync(); - // Complete the application so that the connection unwinds without aborting - applicationSide.Out.TryComplete(); + // Complete the application so that the connection unwinds without aborting + applicationSide.Writer.TryComplete(); + } + + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory); + + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); + + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); + + // When the close frame is received, we complete the application so the send + // loop unwinds + _ = CompleteApplicationAfterTransportCompletes(); + + // Terminate the client to server channel with an exception + feature.Client.SendAbort(); + + // Wait for the transport + await Assert.ThrowsAsync(() => transport).OrTimeout(); + + var summary = await client.OrTimeout(); + Assert.Equal(WebSocketCloseStatus.InternalServerError, summary.CloseResult.CloseStatus); } - - var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory: new LoggerFactory()); - - // Give the server socket to the transport and run it - var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); - - // Run the client socket - var client = feature.Client.ExecuteAndCaptureFramesAsync(); - - // When the close frame is received, we complete the application so the send - // loop unwinds - _ = CompleteApplicationAfterTransportCompletes(); - - // Terminate the client to server channel with an exception - feature.Client.SendAbort(); - - // Wait for the transport - await Assert.ThrowsAsync(() => transport).OrTimeout(); - - var summary = await client.OrTimeout(); - Assert.Equal(WebSocketCloseStatus.InternalServerError, summary.CloseResult.CloseStatus); } } [Fact] public async Task ClientReceivesInternalServerErrorWhenTheApplicationFails() { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); - - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) - using (var feature = new TestWebSocketConnectionFeature()) + using (StartLog(out var loggerFactory)) { - var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory: new LoggerFactory()); + var transportToApplication = Channel.CreateUnbounded(); + var applicationToTransport = Channel.CreateUnbounded(); - // Give the server socket to the transport and run it - var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); + using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) + using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) + using (var feature = new TestWebSocketConnectionFeature()) + { + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory); - // Run the client socket - var client = feature.Client.ExecuteAndCaptureFramesAsync(); + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); - // Fail in the app - Assert.True(applicationSide.Out.TryComplete(new InvalidOperationException("Catastrophic failure."))); - var clientSummary = await client.OrTimeout(); - Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.CloseStatus); + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); - // Close from the client - await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + // Fail in the app + Assert.True(applicationSide.Writer.TryComplete(new InvalidOperationException("Catastrophic failure."))); + var clientSummary = await client.OrTimeout(); + Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.CloseStatus); - var ex = await Assert.ThrowsAsync(() => transport.OrTimeout()); - Assert.Equal("Catastrophic failure.", ex.Message); + // Close from the client + await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + + var ex = await Assert.ThrowsAsync(() => transport.OrTimeout()); + Assert.Equal("Catastrophic failure.", ex.Message); + } } } [Fact] public async Task TransportClosesOnCloseTimeoutIfClientDoesNotSendCloseFrame() { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); - - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) - using (var feature = new TestWebSocketConnectionFeature()) + using (StartLog(out var loggerFactory)) { - var options = new WebSocketOptions() + var transportToApplication = Channel.CreateUnbounded(); + var applicationToTransport = Channel.CreateUnbounded(); + + using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) + using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) + using (var feature = new TestWebSocketConnectionFeature()) { - CloseTimeout = TimeSpan.FromSeconds(1) - }; + var options = new WebSocketOptions() + { + CloseTimeout = TimeSpan.FromSeconds(1) + }; - var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory: new LoggerFactory()); + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory); - var serverSocket = await feature.AcceptAsync(); - // Give the server socket to the transport and run it - var transport = ws.ProcessSocketAsync(serverSocket); + var serverSocket = await feature.AcceptAsync(); + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(serverSocket); - // End the app - applicationSide.Dispose(); + // End the app + applicationSide.Dispose(); - await transport.OrTimeout(TimeSpan.FromSeconds(10)); + await transport.OrTimeout(TimeSpan.FromSeconds(10)); - // Now we're closed - Assert.Equal(WebSocketState.Aborted, serverSocket.State); + // Now we're closed + Assert.Equal(WebSocketState.Aborted, serverSocket.State); - serverSocket.Dispose(); + serverSocket.Dispose(); + } } } [Fact] public async Task TransportFailsOnTimeoutWithErrorWhenApplicationFailsAndClientDoesNotSendCloseFrame() { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); - - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) - using (var feature = new TestWebSocketConnectionFeature()) + using (StartLog(out var loggerFactory)) { - var options = new WebSocketOptions + var transportToApplication = Channel.CreateUnbounded(); + var applicationToTransport = Channel.CreateUnbounded(); + + using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) + using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) + using (var feature = new TestWebSocketConnectionFeature()) { - CloseTimeout = TimeSpan.FromSeconds(1) - }; + var options = new WebSocketOptions + { + CloseTimeout = TimeSpan.FromSeconds(1) + }; - var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory: new LoggerFactory()); + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory); - var serverSocket = await feature.AcceptAsync(); - // Give the server socket to the transport and run it - var transport = ws.ProcessSocketAsync(serverSocket); + var serverSocket = await feature.AcceptAsync(); + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(serverSocket); - // Run the client socket - var client = feature.Client.ExecuteAndCaptureFramesAsync(); + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); - // fail the client to server channel - applicationToTransport.Out.TryComplete(new Exception()); + // fail the client to server channel + applicationToTransport.Writer.TryComplete(new Exception()); - await Assert.ThrowsAsync(() => transport).OrTimeout(); + await Assert.ThrowsAsync(() => transport).OrTimeout(); - Assert.Equal(WebSocketState.Aborted, serverSocket.State); + Assert.Equal(WebSocketState.Aborted, serverSocket.State); + } } } [Fact] public async Task ServerGracefullyClosesWhenApplicationEndsThenClientSendsCloseFrame() { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); - - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) - using (var feature = new TestWebSocketConnectionFeature()) + using (StartLog(out var loggerFactory)) { - var options = new WebSocketOptions + var transportToApplication = Channel.CreateUnbounded(); + var applicationToTransport = Channel.CreateUnbounded(); + + using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) + using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) + using (var feature = new TestWebSocketConnectionFeature()) { - // We want to verify behavior without timeout affecting it - CloseTimeout = TimeSpan.FromSeconds(20) - }; + var options = new WebSocketOptions + { + // We want to verify behavior without timeout affecting it + CloseTimeout = TimeSpan.FromSeconds(20) + }; - var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory: new LoggerFactory()); + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory); - var serverSocket = await feature.AcceptAsync(); - // Give the server socket to the transport and run it - var transport = ws.ProcessSocketAsync(serverSocket); + var serverSocket = await feature.AcceptAsync(); + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(serverSocket); - // Run the client socket - var client = feature.Client.ExecuteAndCaptureFramesAsync(); + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); - // close the client to server channel - applicationToTransport.Out.TryComplete(); + // close the client to server channel + applicationToTransport.Writer.TryComplete(); - _ = await client.OrTimeout(); + _ = await client.OrTimeout(); - await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).OrTimeout(); + await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).OrTimeout(); - await transport.OrTimeout(); + await transport.OrTimeout(); - Assert.Equal(WebSocketCloseStatus.NormalClosure, serverSocket.CloseStatus); + Assert.Equal(WebSocketCloseStatus.NormalClosure, serverSocket.CloseStatus); + } } } [Fact] public async Task ServerGracefullyClosesWhenClientSendsCloseFrameThenApplicationEnds() { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); - - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) - using (var feature = new TestWebSocketConnectionFeature()) + using (StartLog(out var loggerFactory)) { - var options = new WebSocketOptions + var transportToApplication = Channel.CreateUnbounded(); + var applicationToTransport = Channel.CreateUnbounded(); + + using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) + using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) + using (var feature = new TestWebSocketConnectionFeature()) { - // We want to verify behavior without timeout affecting it - CloseTimeout = TimeSpan.FromSeconds(20) - }; - var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory: new LoggerFactory()); + var options = new WebSocketOptions + { + // We want to verify behavior without timeout affecting it + CloseTimeout = TimeSpan.FromSeconds(20) + }; + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory); - var serverSocket = await feature.AcceptAsync(); - // Give the server socket to the transport and run it - var transport = ws.ProcessSocketAsync(serverSocket); + var serverSocket = await feature.AcceptAsync(); + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(serverSocket); - // Run the client socket - var client = feature.Client.ExecuteAndCaptureFramesAsync(); + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); - await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).OrTimeout(); + await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).OrTimeout(); - // close the client to server channel - applicationToTransport.Out.TryComplete(); + // close the client to server channel + applicationToTransport.Writer.TryComplete(); - _ = await client.OrTimeout(); + _ = await client.OrTimeout(); - await transport.OrTimeout(); + await transport.OrTimeout(); - Assert.Equal(WebSocketCloseStatus.NormalClosure, serverSocket.CloseStatus); + Assert.Equal(WebSocketCloseStatus.NormalClosure, serverSocket.CloseStatus); + } } } }