From 28439d14417412b0797f2643cc47b0fceedca0d3 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Fri, 9 Feb 2018 17:45:21 -0800 Subject: [PATCH] Initial changes to move to pipelines (#1424) - Change the Sockets abstraction from Channel to pipelines. #615 --- .../BroadcastBenchmark.cs | 11 +-- build/dependencies.props | 1 + client-ts/FunctionalTests/EchoEndPoint.cs | 16 ++- .../PersistentConnectionLifeTimeManager.cs | 3 +- .../SocialWeather/SocialWeatherEndPoint.cs | 28 ++++-- .../EndPoints/MessagesEndPoint.cs | 29 ++++-- .../Properties/launchSettings.json | 2 +- .../Internal/HubProtocolReaderWriter.cs | 11 +++ .../Internal/Protocol/NegotiationProtocol.cs | 24 +++++ ...Microsoft.AspNetCore.SignalR.Common.csproj | 2 + .../HubConnectionContext.cs | 77 +++++++++------ .../HubEndPoint.cs | 92 ++++++++++-------- .../ConnectionContext.cs | 4 +- .../DuplexPipe.cs | 47 +++++++++ .../Features/IConnectionTransportFeature.cs | 4 +- ...oft.AspNetCore.Sockets.Abstractions.csproj | 3 + .../HttpConnectionDispatcher.cs | 15 +-- .../Internal/SocketHttpLoggerExtensions.cs | 24 ++--- .../Transports/LongPollingTransport.cs | 41 ++++---- .../Transports/ServerSentEventsTransport.cs | 39 +++++--- .../Transports/WebSocketsTransport.cs | 87 +++++++---------- .../ConnectionManager.cs | 11 +-- .../DefaultConnectionContext.cs | 29 ++++-- .../TestClient.cs | 77 +++++++++++---- .../EchoEndPoint.cs | 17 +++- .../EndToEndTests.cs | 3 + .../HubEndpointTests.cs | 18 ++-- .../ConnectionManagerTests.cs | 30 +++++- .../HttpConnectionDispatcherTests.cs | 85 ++++++++++++++-- .../LongPollingTests.cs | 46 ++++----- .../MapEndPointTests.cs | 10 +- .../ServerSentEventsTests.cs | 47 ++++----- .../WebSocketsTests.cs | 97 ++++++++----------- 33 files changed, 658 insertions(+), 372 deletions(-) create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/DuplexPipe.cs diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs index 11213fa5c4..a2121063e3 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs @@ -1,11 +1,10 @@ using System; +using System.IO.Pipelines; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; using BenchmarkDotNet.Attributes; -using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; -using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.Logging.Abstractions; namespace Microsoft.AspNetCore.SignalR.Microbenchmarks @@ -26,12 +25,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks for (var i = 0; i < Connections; ++i) { - var transportToApplication = Channel.CreateUnbounded(options); - var applicationToTransport = Channel.CreateUnbounded(options); - - var application = ChannelConnection.Create(input: applicationToTransport, output: transportToApplication); - var transport = ChannelConnection.Create(input: transportToApplication, output: applicationToTransport); - var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), transport, application); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport); _hubLifetimeManager.OnConnectedAsync(new HubConnectionContext(connection, Timeout.InfiniteTimeSpan, NullLoggerFactory.Instance)).Wait(); } diff --git a/build/dependencies.props b/build/dependencies.props index 71ea360d22..eea44544e8 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -60,6 +60,7 @@ 10.0.1 1.2.4 4.5.0-preview2-26130-01 + 0.1.0-preview2-180130-1 0.1.0-preview2-180130-1 4.5.0-preview2-26130-01 4.5.0-preview2-26130-01 diff --git a/client-ts/FunctionalTests/EchoEndPoint.cs b/client-ts/FunctionalTests/EchoEndPoint.cs index cd6442c01f..0680acc8ca 100644 --- a/client-ts/FunctionalTests/EchoEndPoint.cs +++ b/client-ts/FunctionalTests/EchoEndPoint.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.IO.Pipelines; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; @@ -10,7 +11,20 @@ namespace FunctionalTests { public async override Task OnConnectedAsync(ConnectionContext connection) { - await connection.Transport.Writer.WriteAsync(await connection.Transport.Reader.ReadAsync()); + var result = await connection.Transport.Input.ReadAsync(); + var buffer = result.Buffer; + + try + { + if (!buffer.IsEmpty) + { + await connection.Transport.Output.WriteAsync(buffer.ToArray()); + } + } + finally + { + connection.Transport.Input.AdvanceTo(result.Buffer.End); + } } } } diff --git a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs index e94f5045f4..01093214a6 100644 --- a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs +++ b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.IO.Pipelines; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Features; @@ -50,7 +51,7 @@ namespace SocialWeather var ms = new MemoryStream(); await formatter.WriteAsync(data, ms); - connection.Transport.Writer.TryWrite(ms.ToArray()); + await connection.Transport.Output.WriteAsync(ms.ToArray()); } } diff --git a/samples/SocialWeather/SocialWeatherEndPoint.cs b/samples/SocialWeather/SocialWeatherEndPoint.cs index 9b6ee0a547..84dd03d919 100644 --- a/samples/SocialWeather/SocialWeatherEndPoint.cs +++ b/samples/SocialWeather/SocialWeatherEndPoint.cs @@ -34,15 +34,29 @@ namespace SocialWeather var formatter = _formatterResolver.GetFormatter( (string)connection.Metadata["format"]); - while (await connection.Transport.Reader.WaitToReadAsync()) + while (true) { - if (connection.Transport.Reader.TryRead(out var buffer)) + var result = await connection.Transport.Input.ReadAsync(); + var buffer = result.Buffer; + try { - var stream = new MemoryStream(); - await stream.WriteAsync(buffer, 0, buffer.Length); - stream.Position = 0; - var weatherReport = await formatter.ReadAsync(stream); - await _lifetimeManager.SendToAllAsync(weatherReport); + if (!buffer.IsEmpty) + { + var stream = new MemoryStream(); + var data = buffer.ToArray(); + await stream.WriteAsync(data, 0, data.Length); + stream.Position = 0; + var weatherReport = await formatter.ReadAsync(stream); + await _lifetimeManager.SendToAllAsync(weatherReport); + } + else if (result.IsCompleted) + { + break; + } + } + finally + { + connection.Transport.Input.AdvanceTo(buffer.End); } } } diff --git a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs index a17cb4624a..0fd837f8bf 100644 --- a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs +++ b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; +using System.IO.Pipelines; using System.Text; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; @@ -20,14 +21,28 @@ namespace SocketsSample.EndPoints try { - while (await connection.Transport.Reader.WaitToReadAsync()) + while (true) { - if (connection.Transport.Reader.TryRead(out var buffer)) + var result = await connection.Transport.Input.ReadAsync(); + var buffer = result.Buffer; + + try { - // We can avoid the copy here but we'll deal with that later - var text = Encoding.UTF8.GetString(buffer); - text = $"{connection.ConnectionId}: {text}"; - await Broadcast(Encoding.UTF8.GetBytes(text)); + if (!buffer.IsEmpty) + { + // We can avoid the copy here but we'll deal with that later + var text = Encoding.UTF8.GetString(buffer.ToArray()); + text = $"{connection.ConnectionId}: {text}"; + await Broadcast(Encoding.UTF8.GetBytes(text)); + } + else if (result.IsCompleted) + { + break; + } + } + finally + { + connection.Transport.Input.AdvanceTo(buffer.End); } } } @@ -50,7 +65,7 @@ namespace SocketsSample.EndPoints foreach (var c in Connections) { - tasks.Add(c.Transport.Writer.WriteAsync(payload)); + tasks.Add(c.Transport.Output.WriteAsync(payload)); } return Task.WhenAll(tasks); diff --git a/samples/SocketsSample/Properties/launchSettings.json b/samples/SocketsSample/Properties/launchSettings.json index e8e1209314..9bad98807c 100644 --- a/samples/SocketsSample/Properties/launchSettings.json +++ b/samples/SocketsSample/Properties/launchSettings.json @@ -3,7 +3,7 @@ "windowsAuthentication": false, "anonymousAuthentication": true, "iisExpress": { - "applicationUrl": "http://localhost:57707/", + "applicationUrl": "http://localhost:59847/", "sslPort": 0 } }, diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/HubProtocolReaderWriter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/HubProtocolReaderWriter.cs index a684e5e40e..36283170db 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/HubProtocolReaderWriter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/HubProtocolReaderWriter.cs @@ -1,6 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Buffers; +using System.Collections; using System.Collections.Generic; using System.IO; using Microsoft.AspNetCore.SignalR.Internal.Encoders; @@ -19,6 +21,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal _dataEncoder = dataEncoder; } + public bool ReadMessages(ReadOnlyBuffer buffer, IInvocationBinder binder, out IList messages, out SequencePosition consumed, out SequencePosition examined) + { + // TODO: Fix this implementation to be incremental + consumed = buffer.End; + examined = consumed; + + return ReadMessages(buffer.ToArray(), binder, out messages); + } + public bool ReadMessages(byte[] input, IInvocationBinder binder, out IList messages) { var buffer = _dataEncoder.Decode(input); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationProtocol.cs index 11a2ee81fe..3994c1bd0a 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/NegotiationProtocol.cs @@ -2,6 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; +using System.Collections; using System.IO; using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Newtonsoft.Json; @@ -53,5 +55,27 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol } return true; } + + public static bool TryParseMessage(ReadOnlyBuffer buffer, out NegotiationMessage negotiationMessage, out SequencePosition consumed, out SequencePosition examined) + { + var separator = buffer.PositionOf(TextMessageFormatter.RecordSeparator); + if (separator == null) + { + // Haven't seen the entire negotiate message so bail + consumed = buffer.Start; + examined = buffer.End; + negotiationMessage = null; + return false; + } + else + { + consumed = buffer.GetPosition(separator.Value, 1); + examined = consumed; + } + + var memory = buffer.IsSingleSegment ? buffer.First : buffer.ToArray(); + + return TryParseMessage(memory.Span, out negotiationMessage); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj b/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj index 0ee14f9b98..1420cd5b9f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj +++ b/src/Microsoft.AspNetCore.SignalR.Common/Microsoft.AspNetCore.SignalR.Common.csproj @@ -11,6 +11,8 @@ + + diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 0acb43c7f7..0ddc54dbc1 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; +using System.IO.Pipelines; using System.Net; using System.Runtime.ExceptionServices; using System.Security.Claims; @@ -15,6 +16,7 @@ using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.SignalR.Core; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Encoders; +using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Features; @@ -36,7 +38,6 @@ namespace Microsoft.AspNetCore.SignalR private Task _writingTask = Task.CompletedTask; private long _lastSendTimestamp = Stopwatch.GetTimestamp(); - private byte[] _pingMessage; public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory) { @@ -59,7 +60,7 @@ namespace Microsoft.AspNetCore.SignalR public virtual HubProtocolReaderWriter ProtocolReaderWriter { get; set; } - public virtual ChannelReader Input => _connectionContext.Transport.Reader; + public virtual PipeReader Input => _connectionContext.Transport.Input; public string UserIdentifier { get; private set; } @@ -125,37 +126,54 @@ namespace Microsoft.AspNetCore.SignalR using (var cts = new CancellationTokenSource()) { cts.CancelAfter(timeout); - while (await _connectionContext.Transport.Reader.WaitToReadAsync(cts.Token)) + + while (true) { - while (_connectionContext.Transport.Reader.TryRead(out var buffer)) + var result = await _connectionContext.Transport.Input.ReadAsync(cts.Token); + var buffer = result.Buffer; + var consumed = buffer.End; + var examined = buffer.End; + + try { - if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage)) + if (!buffer.IsEmpty) { - var protocol = protocolResolver.GetProtocol(negotiationMessage.Protocol, this); + if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage, out consumed, out examined)) + { + var protocol = protocolResolver.GetProtocol(negotiationMessage.Protocol, this); - var transportCapabilities = Features.Get()?.TransportCapabilities - ?? throw new InvalidOperationException("Unable to read transport capabilities."); + var transportCapabilities = Features.Get()?.TransportCapabilities + ?? throw new InvalidOperationException("Unable to read transport capabilities."); - var dataEncoder = (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) == 0) - ? (IDataEncoder)Base64Encoder - : PassThroughEncoder; + var dataEncoder = (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) == 0) + ? (IDataEncoder)Base64Encoder + : PassThroughEncoder; - var transferModeFeature = Features.Get() ?? - throw new InvalidOperationException("Unable to read transfer mode."); + var transferModeFeature = Features.Get() ?? + throw new InvalidOperationException("Unable to read transfer mode."); - transferModeFeature.TransferMode = - (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) != 0) - ? TransferMode.Binary - : TransferMode.Text; + transferModeFeature.TransferMode = + (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) != 0) + ? TransferMode.Binary + : TransferMode.Text; - ProtocolReaderWriter = new HubProtocolReaderWriter(protocol, dataEncoder); + ProtocolReaderWriter = new HubProtocolReaderWriter(protocol, dataEncoder); - _logger.UsingHubProtocol(protocol.Name); + _logger.UsingHubProtocol(protocol.Name); - UserIdentifier = userIdProvider.GetUserId(this); + UserIdentifier = userIdProvider.GetUserId(this); - return true; + return true; + } } + else if (result.IsCompleted) + { + break; + } + } + finally + { + _connectionContext.Transport.Input.AdvanceTo(consumed, examined); } } } @@ -186,7 +204,6 @@ namespace Microsoft.AspNetCore.SignalR if (Features.Get() == null) { Debug.Assert(ProtocolReaderWriter != null, "Expected the ProtocolReaderWriter to be set before StartAsync is called"); - _pingMessage = ProtocolReaderWriter.WriteMessage(PingMessage.Instance); _connectionContext.Features.Get()?.OnHeartbeat(state => ((HubConnectionContext)state).KeepAliveTick(), this); } @@ -197,14 +214,10 @@ namespace Microsoft.AspNetCore.SignalR while (Output.Reader.TryRead(out var hubMessage)) { var buffer = ProtocolReaderWriter.WriteMessage(hubMessage); - while (await _connectionContext.Transport.Writer.WaitToWriteAsync()) - { - if (_connectionContext.Transport.Writer.TryWrite(buffer)) - { - Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp()); - break; - } - } + + await _connectionContext.Transport.Output.WriteAsync(buffer); + + Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp()); } } } @@ -221,7 +234,6 @@ namespace Microsoft.AspNetCore.SignalR // If it is, we send a ping frame, if not, we no-op on this tick. This means that in the worst case, the // true "ping rate" of the server could be (_hubOptions.KeepAliveInterval + HubEndPoint.KeepAliveTimerInterval), // because if the interval elapses right after the last tick of this timer, it won't be detected until the next tick. - Debug.Assert(_pingMessage != null, "Expected the ping message to be prepared before the first heartbeat tick"); if (Stopwatch.GetTimestamp() - Interlocked.Read(ref _lastSendTimestamp) > _keepAliveDuration) { @@ -229,7 +241,8 @@ namespace Microsoft.AspNetCore.SignalR // If the transport channel is full, this will fail, but that's OK because // adding a Ping message when the transport is full is unnecessary since the // transport is still in the process of sending frames. - if (_connectionContext.Transport.Writer.TryWrite(_pingMessage)) + + if (Output.Writer.TryWrite(PingMessage.Instance)) { _logger.SentPing(); } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index fd2255c308..2ceacf21fa 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -171,58 +171,74 @@ namespace Microsoft.AspNetCore.SignalR try { - while (await connection.Input.WaitToReadAsync(connection.ConnectionAbortedToken)) + while (true) { - while (connection.Input.TryRead(out var buffer)) + var result = await connection.Input.ReadAsync(connection.ConnectionAbortedToken); + var buffer = result.Buffer; + var consumed = buffer.End; + var examined = buffer.End; + + try { - if (connection.ProtocolReaderWriter.ReadMessages(buffer, this, out var hubMessages)) + if (!buffer.IsEmpty) { - foreach (var hubMessage in hubMessages) + if (connection.ProtocolReaderWriter.ReadMessages(buffer, this, out var hubMessages, out consumed, out examined)) { - switch (hubMessage) + foreach (var hubMessage in hubMessages) { - case InvocationMessage invocationMessage: - _logger.ReceivedHubInvocation(invocationMessage); + switch (hubMessage) + { + case InvocationMessage invocationMessage: + _logger.ReceivedHubInvocation(invocationMessage); - // Don't wait on the result of execution, continue processing other - // incoming messages on this connection. - _ = ProcessInvocation(connection, invocationMessage, isStreamedInvocation: false); - break; + // Don't wait on the result of execution, continue processing other + // incoming messages on this connection. + _ = ProcessInvocation(connection, invocationMessage, isStreamedInvocation: false); + break; - case StreamInvocationMessage streamInvocationMessage: - _logger.ReceivedStreamHubInvocation(streamInvocationMessage); + case StreamInvocationMessage streamInvocationMessage: + _logger.ReceivedStreamHubInvocation(streamInvocationMessage); - // Don't wait on the result of execution, continue processing other - // incoming messages on this connection. - _ = ProcessInvocation(connection, streamInvocationMessage, isStreamedInvocation: true); - break; + // Don't wait on the result of execution, continue processing other + // incoming messages on this connection. + _ = ProcessInvocation(connection, streamInvocationMessage, isStreamedInvocation: true); + break; - case CancelInvocationMessage cancelInvocationMessage: - // Check if there is an associated active stream and cancel it if it exists. - // The cts will be removed when the streaming method completes executing - if (connection.ActiveRequestCancellationSources.TryGetValue(cancelInvocationMessage.InvocationId, out var cts)) - { - _logger.CancelStream(cancelInvocationMessage.InvocationId); - cts.Cancel(); - } - else - { - // Stream can be canceled on the server while client is canceling stream. - _logger.UnexpectedCancel(); - } - break; + case CancelInvocationMessage cancelInvocationMessage: + // Check if there is an associated active stream and cancel it if it exists. + // The cts will be removed when the streaming method completes executing + if (connection.ActiveRequestCancellationSources.TryGetValue(cancelInvocationMessage.InvocationId, out var cts)) + { + _logger.CancelStream(cancelInvocationMessage.InvocationId); + cts.Cancel(); + } + else + { + // Stream can be canceled on the server while client is canceling stream. + _logger.UnexpectedCancel(); + } + break; - case PingMessage _: - // We don't care about pings - break; + case PingMessage _: + // We don't care about pings + break; - // Other kind of message we weren't expecting - default: - _logger.UnsupportedMessageReceived(hubMessage.GetType().FullName); - throw new NotSupportedException($"Received unsupported message: {hubMessage}"); + // Other kind of message we weren't expecting + default: + _logger.UnsupportedMessageReceived(hubMessage.GetType().FullName); + throw new NotSupportedException($"Received unsupported message: {hubMessage}"); + } } } } + else if (result.IsCompleted) + { + break; + } + } + finally + { + connection.Input.AdvanceTo(consumed, examined); } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs index 8f4c799a16..53779be11d 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs @@ -2,7 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; -using System.Threading.Channels; +using System.IO.Pipelines; using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Sockets @@ -15,6 +15,6 @@ namespace Microsoft.AspNetCore.Sockets public abstract IDictionary Metadata { get; set; } - public abstract Channel Transport { get; set; } + public abstract IDuplexPipe Transport { get; set; } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/DuplexPipe.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/DuplexPipe.cs new file mode 100644 index 0000000000..d751a1c8b4 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/DuplexPipe.cs @@ -0,0 +1,47 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace System.IO.Pipelines +{ + public class DuplexPipe : IDuplexPipe + { + public DuplexPipe(PipeReader reader, PipeWriter writer) + { + Input = reader; + Output = writer; + } + + public PipeReader Input { get; } + + public PipeWriter Output { get; } + + public void Dispose() + { + + } + + public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); + + var transportToApplication = new DuplexPipe(output.Reader, input.Writer); + var applicationToTransport = new DuplexPipe(input.Reader, output.Writer); + + return new DuplexPipePair(applicationToTransport, transportToApplication); + } + + // This class exists to work around issues with value tuple on .NET Framework + public struct DuplexPipePair + { + public IDuplexPipe Transport { get; private set; } + public IDuplexPipe Application { get; private set; } + + public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) + { + Transport = transport; + Application = application; + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs index e851b49bcc..3d2a412a4e 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs @@ -1,13 +1,13 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.Threading.Channels; +using System.IO.Pipelines; namespace Microsoft.AspNetCore.Sockets.Features { public interface IConnectionTransportFeature { - Channel Transport { get; set; } + IDuplexPipe Transport { get; set; } TransferMode TransportCapabilities { get; set; } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj b/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj index e1e43ecaf5..4ad4e73092 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj @@ -9,6 +9,9 @@ + + + diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index e4c66205e0..9ea7aaf9b5 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -4,6 +4,7 @@ using System; using System.Diagnostics; using System.IO; +using System.IO.Pipelines; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -116,7 +117,7 @@ namespace Microsoft.AspNetCore.Sockets connection.TransportCapabilities = TransferMode.Text; // We only need to provide the Input channel since writing to the application is handled through /send. - var sse = new ServerSentEventsTransport(connection.Application.Reader, connection.ConnectionId, _loggerFactory); + var sse = new ServerSentEventsTransport(connection.Application.Input, connection.ConnectionId, _loggerFactory); await DoPersistentConnection(socketDelegate, sse, context, connection); } @@ -218,7 +219,7 @@ namespace Microsoft.AspNetCore.Sockets context.Response.RegisterForDispose(timeoutSource); context.Response.RegisterForDispose(tokenSource); - var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Reader, connection.ConnectionId, _loggerFactory); + var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Input, connection.ConnectionId, _loggerFactory); // Start the transport connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token); @@ -239,7 +240,7 @@ namespace Microsoft.AspNetCore.Sockets if (resultTask == connection.ApplicationTask) { // Complete the transport (notifying it of the application error if there is one) - connection.Transport.Writer.TryComplete(connection.ApplicationTask.Exception); + connection.Transport.Output.Complete(connection.ApplicationTask.Exception); // Wait for the transport to run await connection.TransportTask; @@ -440,13 +441,7 @@ namespace Microsoft.AspNetCore.Sockets } _logger.ReceivedBytes(buffer.Length); - while (!connection.Application.Writer.TryWrite(buffer)) - { - if (!await connection.Application.Writer.WaitToWriteAsync()) - { - return; - } - } + await connection.Application.Output.WriteAsync(buffer); } private async Task EnsureConnectionStateAsync(DefaultConnectionContext connection, HttpContext context, TransportType transportType, TransportType supportedTransports, ConnectionLogScope logScope, HttpSocketOptions options) diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/SocketHttpLoggerExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/SocketHttpLoggerExtensions.cs index 890816b7de..30ab3a793f 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/SocketHttpLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/SocketHttpLoggerExtensions.cs @@ -16,8 +16,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal private static readonly Action _pollTimedOut = LoggerMessage.Define(LogLevel.Information, new EventId(2, nameof(PollTimedOut)), "Poll request timed out. Sending 200 response to connection."); - private static readonly Action _longPollingWritingMessage = - LoggerMessage.Define(LogLevel.Debug, new EventId(3, nameof(LongPollingWritingMessage)), "Writing a {count} byte message to connection."); + private static readonly Action _longPollingWritingMessage = + LoggerMessage.Define(LogLevel.Debug, new EventId(3, nameof(LongPollingWritingMessage)), "Writing a {count} byte message to connection."); private static readonly Action _longPollingDisconnected = LoggerMessage.Define(LogLevel.Debug, new EventId(4, nameof(LongPollingDisconnected)), "Client disconnected from Long Polling endpoint for connection."); @@ -41,8 +41,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal private static readonly Action _resumingConnection = LoggerMessage.Define(LogLevel.Debug, new EventId(5, nameof(ResumingConnection)), "Resuming existing connection."); - private static readonly Action _receivedBytes = - LoggerMessage.Define(LogLevel.Debug, new EventId(6, nameof(ReceivedBytes)), "Received {count} bytes."); + private static readonly Action _receivedBytes = + LoggerMessage.Define(LogLevel.Debug, new EventId(6, nameof(ReceivedBytes)), "Received {count} bytes."); private static readonly Action _transportNotSupported = LoggerMessage.Define(LogLevel.Debug, new EventId(7, nameof(TransportNotSupported)), "{transportType} transport not supported by this endpoint type."); @@ -87,8 +87,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal private static readonly Action _messageToApplication = LoggerMessage.Define(LogLevel.Debug, new EventId(10, nameof(MessageToApplication)), "Passing message to application. Payload size: {size}."); - private static readonly Action _sendPayload = - LoggerMessage.Define(LogLevel.Debug, new EventId(11, nameof(SendPayload)), "Sending payload: {size} bytes."); + private static readonly Action _sendPayload = + LoggerMessage.Define(LogLevel.Debug, new EventId(11, nameof(SendPayload)), "Sending payload: {size} bytes."); private static readonly Action _errorWritingFrame = LoggerMessage.Define(LogLevel.Error, new EventId(12, nameof(ErrorWritingFrame)), "Error writing frame."); @@ -97,8 +97,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal LoggerMessage.Define(LogLevel.Trace, new EventId(13, nameof(SendFailed)), "Socket failed to send."); // Category: ServerSentEventsTransport - private static readonly Action _sseWritingMessage = - LoggerMessage.Define(LogLevel.Debug, new EventId(1, nameof(SSEWritingMessage)), "Writing a {count} byte message."); + private static readonly Action _sseWritingMessage = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, nameof(SSEWritingMessage)), "Writing a {count} byte message."); public static void LongPolling204(this ILogger logger) { @@ -110,7 +110,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal _pollTimedOut(logger, null); } - public static void LongPollingWritingMessage(this ILogger logger, int count) + public static void LongPollingWritingMessage(this ILogger logger, long count) { _longPollingWritingMessage(logger, count, null); } @@ -150,7 +150,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal _resumingConnection(logger, null); } - public static void ReceivedBytes(this ILogger logger, int count) + public static void ReceivedBytes(this ILogger logger, long count) { _receivedBytes(logger, count, null); } @@ -225,7 +225,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal _messageToApplication(logger, size, null); } - public static void SendPayload(this ILogger logger, int size) + public static void SendPayload(this ILogger logger, long size) { _sendPayload(logger, size, null); } @@ -240,7 +240,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal _sendFailed(logger, ex); } - public static void SSEWritingMessage(this ILogger logger, int count) + public static void SSEWritingMessage(this ILogger logger, long count) { _sseWritingMessage(logger, count, null); } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/LongPollingTransport.cs index 893e596d78..b6003a37a0 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/LongPollingTransport.cs @@ -2,24 +2,24 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; +using System.Diagnostics; +using System.IO.Pipelines; +using System.Runtime.InteropServices; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Sockets.Internal.Transports { public class LongPollingTransport : IHttpTransport { - private readonly ChannelReader _application; + private readonly PipeReader _application; private readonly ILogger _logger; private readonly CancellationToken _timeoutToken; private readonly string _connectionId; - public LongPollingTransport(CancellationToken timeoutToken, ChannelReader application, string connectionId, ILoggerFactory loggerFactory) + public LongPollingTransport(CancellationToken timeoutToken, PipeReader application, string connectionId, ILoggerFactory loggerFactory) { _timeoutToken = timeoutToken; _application = application; @@ -31,33 +31,38 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports { try { - if (!await _application.WaitToReadAsync(token)) + var result = await _application.ReadAsync(token); + var buffer = result.Buffer; + + if (buffer.IsEmpty && result.IsCompleted) { - await _application.Completion; _logger.LongPolling204(); context.Response.ContentType = "text/plain"; context.Response.StatusCode = StatusCodes.Status204NoContent; return; } - var contentLength = 0; - var buffers = new List(); // We're intentionally not checking cancellation here because we need to drain messages we've got so far, // but it's too late to emit the 204 required by being cancelled. - while (_application.TryRead(out var buffer)) - { - contentLength += buffer.Length; - buffers.Add(buffer); - _logger.LongPollingWritingMessage(buffer.Length); - } + _logger.LongPollingWritingMessage(buffer.Length); - context.Response.ContentLength = contentLength; + context.Response.ContentLength = buffer.Length; context.Response.ContentType = "application/octet-stream"; - foreach (var buffer in buffers) + try { - await context.Response.Body.WriteAsync(buffer, 0, buffer.Length); + foreach (var segment in buffer) + { + var isArray = MemoryMarshal.TryGetArray(segment, out var arraySegment); + // We're using the managed memory pool which is backed by managed buffers + Debug.Assert(isArray); + await context.Response.Body.WriteAsync(arraySegment.Array, 0, arraySegment.Count); + } + } + finally + { + _application.AdvanceTo(buffer.End); } } catch (OperationCanceledException) diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsTransport.cs index 2596446400..2526095b07 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsTransport.cs @@ -3,9 +3,9 @@ using System; using System.IO; +using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; -using System.Threading.Channels; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Sockets.Internal.Formatters; @@ -15,11 +15,11 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports { public class ServerSentEventsTransport : IHttpTransport { - private readonly ChannelReader _application; + private readonly PipeReader _application; private readonly string _connectionId; private readonly ILogger _logger; - public ServerSentEventsTransport(ChannelReader application, string connectionId, ILoggerFactory loggerFactory) + public ServerSentEventsTransport(PipeReader application, string connectionId, ILoggerFactory loggerFactory) { _application = application; _connectionId = connectionId; @@ -44,21 +44,32 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports try { - while (await _application.WaitToReadAsync(token)) + while (true) { - var ms = new MemoryStream(); - while (_application.TryRead(out var buffer)) + var result = await _application.ReadAsync(token); + var buffer = result.Buffer; + + try { - _logger.SSEWritingMessage(buffer.Length); - - ServerSentEventsMessageFormatter.WriteMessage(buffer, ms); + if (!buffer.IsEmpty) + { + var ms = new MemoryStream(); + _logger.SSEWritingMessage(buffer.Length); + // Don't create a copy using ToArray every time + ServerSentEventsMessageFormatter.WriteMessage(buffer.ToArray(), ms); + ms.Seek(0, SeekOrigin.Begin); + await ms.CopyToAsync(context.Response.Body); + } + else if (result.IsCompleted) + { + break; + } + } + finally + { + _application.AdvanceTo(buffer.End); } - - ms.Seek(0, SeekOrigin.Begin); - await ms.CopyToAsync(context.Response.Body); } - - await _application.Completion; } catch (OperationCanceledException) { diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs index e3ec5aa6f0..ed90a9e099 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs @@ -4,9 +4,9 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.IO.Pipelines; using System.Net.WebSockets; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; @@ -17,10 +17,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports { private readonly WebSocketOptions _options; private readonly ILogger _logger; - private readonly Channel _application; + private readonly IDuplexPipe _application; private readonly DefaultConnectionContext _connection; - public WebSocketsTransport(WebSocketOptions options, Channel application, DefaultConnectionContext connection, ILoggerFactory loggerFactory) + public WebSocketsTransport(WebSocketOptions options, IDuplexPipe application, DefaultConnectionContext connection, ILoggerFactory loggerFactory) { if (options == null) { @@ -86,9 +86,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports _logger.WaitingForClose(); } - // We're done writing - _application.Writer.TryComplete(); - await socket.CloseOutputAsync(failed ? WebSocketCloseStatus.InternalServerError : WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); var resultTask = await Task.WhenAny(task, Task.Delay(_options.CloseTimeout)); @@ -110,20 +107,17 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports private async Task StartReceiving(WebSocket socket) { - // REVIEW: This code was copied from the client, it's highly unoptimized at the moment (especially - // for server logic) - var incomingMessage = new List>(); - while (true) + try { - const int bufferSize = 4096; - var totalBytes = 0; - WebSocketReceiveResult receiveResult; - do + while (true) { - var buffer = new ArraySegment(new byte[bufferSize]); + var memory = _application.Output.GetMemory(); + + // REVIEW: Use new Memory websocket APIs on .NET Core 2.1 + memory.TryGetArray(out var arraySegment); // Exceptions are handled above where the send and receive tasks are being run. - receiveResult = await socket.ReceiveAsync(buffer, CancellationToken.None); + var receiveResult = await socket.ReceiveAsync(arraySegment, CancellationToken.None); if (receiveResult.MessageType == WebSocketMessageType.Close) { return receiveResult; @@ -131,54 +125,33 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports _logger.MessageReceived(receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); - var truncBuffer = new ArraySegment(buffer.Array, 0, receiveResult.Count); - incomingMessage.Add(truncBuffer); - totalBytes += receiveResult.Count; - } while (!receiveResult.EndOfMessage); + _application.Output.Advance(receiveResult.Count); - // Making sure the message type is either text or binary - Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text), "Unexpected message type"); - - // TODO: Check received message type against the _options.WebSocketMessageType - - byte[] messageBuffer = null; - - if (incomingMessage.Count > 1) - { - messageBuffer = new byte[totalBytes]; - var offset = 0; - for (var i = 0; i < incomingMessage.Count; i++) + if (receiveResult.EndOfMessage) { - Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count); - offset += incomingMessage[i].Count; - } - } - else - { - messageBuffer = new byte[incomingMessage[0].Count]; - Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, messageBuffer, 0, incomingMessage[0].Count); - } - - _logger.MessageToApplication(messageBuffer.Length); - while (await _application.Writer.WaitToWriteAsync()) - { - if (_application.Writer.TryWrite(messageBuffer)) - { - incomingMessage.Clear(); - break; + await _application.Output.FlushAsync(); } } } + finally + { + // We're done writing + _application.Output.Complete(); + } } private async Task StartSending(WebSocket ws) { - while (await _application.Reader.WaitToReadAsync()) + while (true) { + var result = await _application.Input.ReadAsync(); + var buffer = result.Buffer; + // Get a frame from the application - while (_application.Reader.TryRead(out var buffer)) + + try { - if (buffer.Length > 0) + if (!buffer.IsEmpty) { try { @@ -190,7 +163,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports if (WebSocketCanSend(ws)) { - await ws.SendAsync(new ArraySegment(buffer), webSocketMessageType, endOfMessage: true, cancellationToken: CancellationToken.None); + await ws.SendAsync(new ArraySegment(buffer.ToArray()), webSocketMessageType, endOfMessage: true, cancellationToken: CancellationToken.None); } } catch (WebSocketException socketException) when (!WebSocketCanSend(ws)) @@ -205,6 +178,14 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports break; } } + else if (result.IsCompleted) + { + break; + } + } + finally + { + _application.Input.AdvanceTo(buffer.End); } } } diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index 55bb38b3e7..832cf3c311 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -6,6 +6,7 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.IO; +using System.IO.Pipelines; using System.Net.WebSockets; using System.Threading; using System.Threading.Channels; @@ -63,13 +64,9 @@ namespace Microsoft.AspNetCore.Sockets _logger.CreatedNewConnection(id); var connectionTimer = SocketEventSource.Log.ConnectionStart(id); - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); - - var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication); - var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport); - - var connection = new DefaultConnectionContext(id, applicationSide, transportSide); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + + var connection = new DefaultConnectionContext(id, pair.Application, pair.Transport); connection.ConnectionTimer = connectionTimer; _connections.TryAdd(id, connection); diff --git a/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs index 6faa0b0b26..eb7b727673 100644 --- a/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs @@ -3,9 +3,9 @@ using System; using System.Collections.Generic; +using System.IO.Pipelines; using System.Security.Claims; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Sockets.Features; @@ -28,7 +28,7 @@ namespace Microsoft.AspNetCore.Sockets private TaskCompletionSource _disposeTcs = new TaskCompletionSource(); internal ValueStopwatch ConnectionTimer { get; set; } - public DefaultConnectionContext(string id, Channel transport, Channel application) + public DefaultConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application) { Transport = transport; Application = application; @@ -65,9 +65,9 @@ namespace Microsoft.AspNetCore.Sockets public override IDictionary Metadata { get; set; } = new ConnectionMetadata(); - public Channel Application { get; } + public IDuplexPipe Application { get; } - public override Channel Transport { get; set; } + public override IDuplexPipe Transport { get; set; } public TransferMode TransportCapabilities { get; set; } @@ -111,21 +111,21 @@ namespace Microsoft.AspNetCore.Sockets // If the application task is faulted, propagate the error to the transport if (ApplicationTask?.IsFaulted == true) { - Transport.Writer.TryComplete(ApplicationTask.Exception.InnerException); + Transport.Output.Complete(ApplicationTask.Exception.InnerException); } else { - Transport.Writer.TryComplete(); + Transport.Output.Complete(); } // If the transport task is faulted, propagate the error to the application if (TransportTask?.IsFaulted == true) { - Application.Writer.TryComplete(TransportTask.Exception.InnerException); + Application.Output.Complete(TransportTask.Exception.InnerException); } else { - Application.Writer.TryComplete(); + Application.Output.Complete(); } var applicationTask = ApplicationTask ?? Task.CompletedTask; @@ -139,7 +139,18 @@ namespace Microsoft.AspNetCore.Sockets Lock.Release(); } - await disposeTask; + try + { + await disposeTask; + } + finally + { + // REVIEW: Should we move this to the read loops? + + // Complete the reading side of the pipes + Application.Input.Complete(); + Transport.Input.Complete(); + } } private async Task WaitOnTasks(Task applicationTask, Task transportTask) diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs index fe9daf6b21..f602324733 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs @@ -2,18 +2,17 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; using System.Collections.Generic; using System.IO; +using System.IO.Pipelines; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; -using System.Threading.Channels; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Encoders; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; -using Microsoft.AspNetCore.Sockets.Internal; -using Newtonsoft.Json; namespace Microsoft.AspNetCore.SignalR.Tests { @@ -23,22 +22,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests private readonly HubProtocolReaderWriter _protocolReaderWriter; private readonly IInvocationBinder _invocationBinder; private CancellationTokenSource _cts; - private ChannelConnection _transport; + private Queue _messages = new Queue(); public DefaultConnectionContext Connection { get; } - public Channel Application { get; } public Task Connected => ((TaskCompletionSource)Connection.Metadata["ConnectedTask"]).Task; public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, bool addClaimId = false) { - var options = new UnboundedChannelOptions { AllowSynchronousContinuations = synchronousCallbacks }; - var transportToApplication = Channel.CreateUnbounded(options); - var applicationToTransport = Channel.CreateUnbounded(options); - - Application = ChannelConnection.Create(input: applicationToTransport, output: transportToApplication); - _transport = ChannelConnection.Create(input: transportToApplication, output: applicationToTransport); - - Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), _transport, Application); + var options = new PipeOptions(readerScheduler: synchronousCallbacks ? PipeScheduler.Inline : null); + var pair = DuplexPipe.CreateConnectionPair(options, options); + Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Transport, pair.Application); var claimValue = Interlocked.Increment(ref _id).ToString(); var claims = new List { new Claim(ClaimTypes.Name, claimValue) }; @@ -59,7 +52,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var memoryStream = new MemoryStream()) { NegotiationProtocol.WriteMessage(new NegotiationMessage(protocol.Name), memoryStream); - Application.Writer.TryWrite(memoryStream.ToArray()); + Connection.Application.Output.WriteAsync(memoryStream.ToArray()); } } @@ -151,7 +144,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public async Task SendHubMessageAsync(HubMessage message) { var payload = _protocolReaderWriter.WriteMessage(message); - await Application.Writer.WriteAsync(payload); + await Connection.Application.Output.WriteAsync(payload); return message is HubInvocationMessage hubMessage ? hubMessage.InvocationId : null; } @@ -163,9 +156,24 @@ namespace Microsoft.AspNetCore.SignalR.Tests if (message == null) { - if (!await Application.Reader.WaitToReadAsync()) + var result = await Connection.Application.Input.ReadAsync(); + var buffer = result.Buffer; + + try { - return null; + if (!buffer.IsEmpty) + { + continue; + } + + if (result.IsCompleted) + { + return null; + } + } + finally + { + Connection.Application.Input.AdvanceTo(buffer.Start); } } else @@ -177,18 +185,45 @@ namespace Microsoft.AspNetCore.SignalR.Tests public HubMessage TryRead() { - if (Application.Reader.TryRead(out var buffer) && - _protocolReaderWriter.ReadMessages(buffer, _invocationBinder, out var messages)) + if (_messages.Count > 0) { - return messages[0]; + return _messages.Dequeue(); } + + if (!Connection.Application.Input.TryRead(out var result)) + { + return null; + } + + var buffer = result.Buffer; + var consumed = buffer.End; + var examined = consumed; + + try + { + if (_protocolReaderWriter.ReadMessages(result.Buffer, _invocationBinder, out var messages, out consumed, out examined)) + { + foreach (var m in messages) + { + _messages.Enqueue(m); + } + + return _messages.Dequeue(); + } + } + finally + { + Connection.Application.Input.AdvanceTo(consumed, examined); + } + return null; } public void Dispose() { _cts.Cancel(); - _transport.Dispose(); + + Connection.Application.Output.Complete(); } private static string GetInvocationId() diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs index 034230b956..839e98088b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs @@ -1,8 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.IO.Pipelines; using System.Threading.Tasks; -using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.Sockets; namespace Microsoft.AspNetCore.SignalR.Tests @@ -11,7 +11,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests { public async override Task OnConnectedAsync(ConnectionContext connection) { - await connection.Transport.Writer.WriteAsync(await connection.Transport.Reader.ReadAsync()); + var result = await connection.Transport.Input.ReadAsync(); + var buffer = result.Buffer; + + try + { + if (!buffer.IsEmpty) + { + await connection.Transport.Output.WriteAsync(buffer.ToArray()); + } + } + finally + { + connection.Transport.Input.AdvanceTo(buffer.End); + } } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index eab73d564b..80d7ad05fd 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -92,6 +92,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.Equal(bytes, buffer.Array.AsSpan().Slice(0, result.Count).ToArray()); + logger.LogInformation("Waiting for close"); + result = await ws.ReceiveAsync(buffer, CancellationToken.None).OrTimeout(); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); logger.LogInformation("Closing socket"); await ws.CloseAsync(WebSocketCloseStatus.Empty, "", CancellationToken.None).OrTimeout(); logger.LogInformation("Closed socket"); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 8f65275984..2e84b84ca7 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -234,13 +234,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests await waitForSubscribe.Task.OrTimeout(); - observable.OnNext(1); - await client.SendHubMessageAsync(new CancelInvocationMessage(invocationId)).OrTimeout(); await waitForDispose.Task.OrTimeout(); - Assert.Equal(1L, ((StreamItemMessage)await client.ReadAsync().OrTimeout()).Item); + var message = await client.ReadAsync().OrTimeout(); + + Assert.IsType(message); client.Dispose(); @@ -257,7 +257,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var client = new TestClient()) { // TestClient automatically writes negotiate, for this test we want to assume negotiate never gets sent - client.Connection.Transport.Reader.TryRead(out var item); + client.Connection.Transport.Input.TryRead(out var item); + client.Connection.Transport.Input.AdvanceTo(item.Buffer.End); var endPointTask = endPoint.OnConnectedAsync(client.Connection); @@ -283,7 +284,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var client = new TestClient()) { // TestClient automatically writes negotiate, for this test we want to assume negotiate never gets sent - client.Connection.Transport.Reader.TryRead(out var item); + client.Connection.Transport.Input.TryRead(out var item); + client.Connection.Transport.Input.AdvanceTo(item.Buffer.End); await endPoint.OnConnectedAsync(client.Connection).OrTimeout(); } @@ -567,7 +569,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests client.Dispose(); // Nothing should have been written - Assert.False(client.Application.Reader.TryRead(out var buffer)); + Assert.False(client.Connection.Application.Input.TryRead(out var buffer)); await endPointTask.OrTimeout(); } @@ -1720,6 +1722,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests await endPointLifetime.OrTimeout(); + client.Connection.Transport.Output.Complete(); + // We shouldn't have any ping messages HubMessage message; var counter = 0; @@ -1760,6 +1764,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests await endPointLifetime.OrTimeout(); + client.Connection.Transport.Output.Complete(); + // We should have all pings HubMessage message; var counter = 0; diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs index dc631a2593..742e4c64ac 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs @@ -81,12 +81,29 @@ namespace Microsoft.AspNetCore.Sockets.Tests connection.ApplicationTask = Task.Run(async () => { - Assert.False(await connection.Transport.Reader.WaitToReadAsync()); + var result = await connection.Transport.Input.ReadAsync(); + + try + { + Assert.True(result.IsCompleted); + } + finally + { + connection.Transport.Input.AdvanceTo(result.Buffer.End); + } }); connection.TransportTask = Task.Run(async () => { - Assert.False(await connection.Application.Reader.WaitToReadAsync()); + var result = await connection.Application.Input.ReadAsync(); + try + { + Assert.True(result.IsCompleted); + } + finally + { + connection.Application.Input.AdvanceTo(result.Buffer.End); + } }); connectionManager.CloseConnections(); @@ -188,15 +205,22 @@ namespace Microsoft.AspNetCore.Sockets.Tests { var appLifetime = new TestApplicationLifetime(); var connectionManager = CreateConnectionManager(appLifetime); + var tcs = new TaskCompletionSource(); appLifetime.Start(); var connection = connectionManager.CreateConnection(); + connection.Application.Output.OnReaderCompleted((error, state) => + { + tcs.TrySetResult(null); + }, + null); + appLifetime.StopApplication(); // Connection should be disposed so this should complete immediately - Assert.False(await connection.Application.Writer.WaitToWriteAsync().OrTimeout()); + await tcs.Task.OrTimeout(); } private static ConnectionManager CreateConnectionManager(IApplicationLifetime lifetime = null) diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index ec42fccf51..aac60adee8 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.IO.Pipelines; using System.Net.WebSockets; using System.Security.Claims; using System.Text; @@ -199,6 +200,52 @@ namespace Microsoft.AspNetCore.Sockets.Tests } } + [Theory] + [InlineData(TransportType.LongPolling)] + [InlineData(TransportType.ServerSentEvents)] + public async Task PostSendsToConnection(TransportType transportType) + { + using (StartLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + var connection = manager.CreateConnection(); + connection.Metadata[ConnectionMetadataNames.Transport] = transportType; + + using (var requestBody = new MemoryStream()) + using (var responseBody = new MemoryStream()) + { + var bytes = Encoding.UTF8.GetBytes("Hello World"); + requestBody.Write(bytes, 0, bytes.Length); + requestBody.Seek(0, SeekOrigin.Begin); + + var context = new DefaultHttpContext(); + context.Request.Body = requestBody; + context.Response.Body = responseBody; + + var services = new ServiceCollection(); + services.AddEndPoint(); + services.AddOptions(); + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + var values = new Dictionary(); + values["id"] = connection.ConnectionId; + var qs = new QueryCollection(values); + context.Request.Query = qs; + + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); + + Assert.True(connection.Transport.Input.TryRead(out var result)); + Assert.Equal("Hello World", Encoding.UTF8.GetString(result.Buffer.ToArray())); + connection.Transport.Input.AdvanceTo(result.Buffer.End); + } + } + } + [Theory] [InlineData(TransportType.ServerSentEvents)] [InlineData(TransportType.LongPolling)] @@ -570,7 +617,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var buffer = Encoding.UTF8.GetBytes("Hello World"); // Write to the transport so the poll yields - await connection.Transport.Writer.WriteAsync(buffer); + await connection.Transport.Output.WriteAsync(buffer); await task; @@ -605,7 +652,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var buffer = Encoding.UTF8.GetBytes("Hello World"); // Write to the application - await connection.Application.Writer.WriteAsync(buffer); + await connection.Application.Output.WriteAsync(buffer); await task; @@ -638,7 +685,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var buffer = Encoding.UTF8.GetBytes("Hello World"); // Write to the application - await connection.Application.Writer.WriteAsync(buffer); + await connection.Application.Output.WriteAsync(buffer); await task; @@ -674,7 +721,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests await task1.OrTimeout(); // Send a message from the app to complete Task 2 - await connection.Transport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")); await task2.OrTimeout(); @@ -855,7 +902,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); var endPointTask = dispatcher.ExecuteAsync(context, options, app); - await connection.Transport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout(); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout(); await endPointTask.OrTimeout(); @@ -936,7 +983,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests })); var endPointTask = dispatcher.ExecuteAsync(context, options, app); - await connection.Transport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout(); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout(); await endPointTask.OrTimeout(); @@ -993,7 +1040,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); var endPointTask = dispatcher.ExecuteAsync(context, options, app); - await connection.Transport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout(); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout(); await endPointTask.OrTimeout(); @@ -1229,7 +1276,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests { public override Task OnConnectedAsync(ConnectionContext connection) { - connection.Transport.Reader.WaitToReadAsync().Wait(); + var waitHandle = new ManualResetEventSlim(); + var awaiter = connection.Transport.Input.ReadAsync(); + awaiter.OnCompleted(waitHandle.Set); + waitHandle.Wait(); + + var result = awaiter.GetResult(); + connection.Transport.Input.AdvanceTo(result.Buffer.End); + return Task.CompletedTask; } } @@ -1254,8 +1308,21 @@ namespace Microsoft.AspNetCore.Sockets.Tests { public override async Task OnConnectedAsync(ConnectionContext connection) { - while (await connection.Transport.Reader.WaitToReadAsync()) + while (true) { + var result = await connection.Transport.Input.ReadAsync(); + + try + { + if (result.IsCompleted) + { + break; + } + } + finally + { + connection.Transport.Input.AdvanceTo(result.Buffer.End); + } } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs index 3e88995ef5..c62d55a357 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs @@ -3,12 +3,11 @@ using System; using System.IO; +using System.IO.Pipelines; using System.Text; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Sockets.Features; using Microsoft.AspNetCore.Sockets.Internal.Transports; using Microsoft.Extensions.Logging; using Xunit; @@ -20,14 +19,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task Set204StatusCodeWhenChannelComplete() { - var toApplication = Channel.CreateUnbounded(); - var toTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); var context = new DefaultHttpContext(); - var connection = new DefaultConnectionContext("foo", toTransport, toApplication); - var poll = new LongPollingTransport(CancellationToken.None, toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); - Assert.True(toTransport.Writer.TryComplete()); + connection.Transport.Output.Complete(); await poll.ProcessRequestAsync(context, context.RequestAborted).OrTimeout(); @@ -37,13 +35,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task Set200StatusCodeWhenTimeoutTokenFires() { - var toApplication = Channel.CreateUnbounded(); - var toTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); var context = new DefaultHttpContext(); - var connection = new DefaultConnectionContext("foo", toTransport, toApplication); var timeoutToken = new CancellationToken(true); - var poll = new LongPollingTransport(timeoutToken, toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var poll = new LongPollingTransport(timeoutToken, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); using (var cts = CancellationTokenSource.CreateLinkedTokenSource(timeoutToken, context.RequestAborted)) { @@ -57,18 +54,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task FrameSentAsSingleResponse() { - var toApplication = Channel.CreateUnbounded(); - var toTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); var context = new DefaultHttpContext(); - var connection = new DefaultConnectionContext("foo", toTransport, toApplication); - var poll = new LongPollingTransport(CancellationToken.None, toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); var ms = new MemoryStream(); context.Response.Body = ms; - await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello World")); - - Assert.True(toTransport.Writer.TryComplete()); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello World")); + connection.Transport.Output.Complete(); await poll.ProcessRequestAsync(context, context.RequestAborted).OrTimeout(); @@ -79,20 +74,19 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task MultipleFramesSentAsSingleResponse() { - var toApplication = Channel.CreateUnbounded(); - var toTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); var context = new DefaultHttpContext(); - var connection = new DefaultConnectionContext("foo", toTransport, toApplication); - var poll = new LongPollingTransport(CancellationToken.None, toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); var ms = new MemoryStream(); context.Response.Body = ms; - await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello")); - await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes(" ")); - await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes("World")); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello")); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes(" ")); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("World")); - Assert.True(toTransport.Writer.TryComplete()); + connection.Transport.Output.Complete(); await poll.ProcessRequestAsync(context, context.RequestAborted).OrTimeout(); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs index 1b5b55ccb1..1ab370512a 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs @@ -92,9 +92,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests { public override async Task OnConnectedAsync(ConnectionContext connection) { - while (!await connection.Transport.Reader.WaitToReadAsync()) + while (true) { + var result = await connection.Transport.Input.ReadAsync(); + if (result.IsCompleted) + { + break; + } + + // Consume nothing + connection.Transport.Input.AdvanceTo(result.Buffer.Start); } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs index cf60e1b028..426d354c09 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs @@ -2,8 +2,9 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.IO; +using System.IO.Pipelines; using System.Text; -using System.Threading.Channels; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; @@ -18,14 +19,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task SSESetsContentType() { - var toApplication = Channel.CreateUnbounded(); - var toTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); var context = new DefaultHttpContext(); - var connection = new DefaultConnectionContext("foo", toTransport, toApplication); - var sse = new ServerSentEventsTransport(toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var sse = new ServerSentEventsTransport(connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); - Assert.True(toTransport.Writer.TryComplete()); + connection.Transport.Output.Complete(); await sse.ProcessRequestAsync(context, context.RequestAborted); @@ -36,16 +36,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task SSETurnsResponseBufferingOff() { - var toApplication = Channel.CreateUnbounded(); - var toTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); var context = new DefaultHttpContext(); - var connection = new DefaultConnectionContext("foo", toTransport, toApplication); var feature = new HttpBufferingFeature(); context.Features.Set(feature); - var sse = new ServerSentEventsTransport(toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var sse = new ServerSentEventsTransport(connection.Application.Input, connectionId: connection.ConnectionId, loggerFactory: new LoggerFactory()); - Assert.True(toTransport.Writer.TryComplete()); + connection.Transport.Output.Complete(); await sse.ProcessRequestAsync(context, context.RequestAborted); @@ -55,25 +54,22 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task SSEWritesMessages() { - var toApplication = Channel.CreateUnbounded(); - var toTransport = Channel.CreateUnbounded(new UnboundedChannelOptions - { - AllowSynchronousContinuations = true - }); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, new PipeOptions(readerScheduler: PipeScheduler.Inline)); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); var context = new DefaultHttpContext(); - var connection = new DefaultConnectionContext("foo", toTransport, toApplication); + var ms = new MemoryStream(); context.Response.Body = ms; - var sse = new ServerSentEventsTransport(toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var sse = new ServerSentEventsTransport(connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); var task = sse.ProcessRequestAsync(context, context.RequestAborted); - await toTransport.Writer.WriteAsync(Encoding.ASCII.GetBytes("Hello")); + await connection.Transport.Output.WriteAsync(Encoding.ASCII.GetBytes("Hello")); Assert.Equal(":\r\ndata: Hello\r\n\r\n", Encoding.ASCII.GetString(ms.ToArray())); - toTransport.Writer.TryComplete(); + connection.Transport.Output.Complete(); await task.OrTimeout(); } @@ -84,18 +80,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests [InlineData("Hello\r\nWorld", ":\r\ndata: Hello\r\ndata: World\r\n\r\n")] public async Task SSEAddsAppropriateFraming(string message, string expected) { - var toApplication = Channel.CreateUnbounded(); - var toTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); var context = new DefaultHttpContext(); - var connection = new DefaultConnectionContext("foo", toTransport, toApplication); - var sse = new ServerSentEventsTransport(toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var sse = new ServerSentEventsTransport(connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); var ms = new MemoryStream(); context.Response.Body = ms; - await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes(message)); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes(message)); - Assert.True(toTransport.Writer.TryComplete()); + connection.Transport.Output.Complete(); await sse.ProcessRequestAsync(context, context.RequestAborted); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs index 2b7af8d770..ded1a6bd09 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -2,12 +2,11 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO.Pipelines; using System.Net.WebSockets; using System.Text; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; -using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Sockets.Internal.Transports; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Testing; @@ -30,15 +29,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests { using (StartLog(out var loggerFactory, LogLevel.Debug)) { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory); + var ws = new WebSocketsTransport(new WebSocketOptions(), connection.Application, connectionContext, loggerFactory); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); @@ -54,10 +51,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests cancellationToken: CancellationToken.None); await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); - var buffer = await applicationSide.Reader.ReadAsync(); - Assert.Equal("Hello", Encoding.UTF8.GetString(buffer)); + var result = await connection.Transport.Input.ReadAsync(); + var buffer = result.Buffer; + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.ToArray())); + connection.Transport.Input.AdvanceTo(buffer.End); - Assert.True(applicationSide.Writer.TryComplete()); + connection.Transport.Output.Complete(); // The transport should finish now await transport; @@ -77,15 +76,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests { using (StartLog(out var loggerFactory, LogLevel.Debug)) { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { var connectionContext = new DefaultConnectionContext(string.Empty, null, null) { TransferMode = transferMode }; - var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory); + var ws = new WebSocketsTransport(new WebSocketOptions(), connection.Application, connectionContext, loggerFactory); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); @@ -94,8 +91,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests var client = feature.Client.ExecuteAndCaptureFramesAsync(); // Write to the output channel, and then complete it - await applicationSide.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello")); - Assert.True(applicationSide.Writer.TryComplete()); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello")); + connection.Transport.Output.Complete(); // The client should finish now, as should the server var clientSummary = await client; @@ -115,24 +112,23 @@ namespace Microsoft.AspNetCore.Sockets.Tests { using (StartLog(out var loggerFactory, LogLevel.Debug)) { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { async Task CompleteApplicationAfterTransportCompletes() { // Wait until the transport completes so that we can end the application - await applicationSide.Reader.WaitToReadAsync(); + var result = await connection.Transport.Input.ReadAsync(); + connection.Transport.Input.AdvanceTo(result.Buffer.End); // Complete the application so that the connection unwinds without aborting - applicationSide.Writer.TryComplete(); + connection.Transport.Output.Complete(); } var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory); + var ws = new WebSocketsTransport(new WebSocketOptions(), connection.Application, connectionContext, loggerFactory); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); @@ -150,8 +146,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests // Wait for the transport await Assert.ThrowsAsync(() => transport).OrTimeout(); - var summary = await client.OrTimeout(); - Assert.Equal(WebSocketCloseStatus.InternalServerError, summary.CloseResult.CloseStatus); + await client.OrTimeout(); } } } @@ -161,15 +156,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests { using (StartLog(out var loggerFactory, LogLevel.Debug)) { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory); + var ws = new WebSocketsTransport(new WebSocketOptions(), connection.Application, connectionContext, loggerFactory); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); @@ -178,7 +171,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var client = feature.Client.ExecuteAndCaptureFramesAsync(); // Fail in the app - Assert.True(applicationSide.Writer.TryComplete(new InvalidOperationException("Catastrophic failure."))); + connection.Transport.Output.Complete(new InvalidOperationException("Catastrophic failure.")); var clientSummary = await client.OrTimeout(); Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.CloseStatus); @@ -196,11 +189,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests { using (StartLog(out var loggerFactory, LogLevel.Debug)) { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { var options = new WebSocketOptions() @@ -209,14 +200,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests }; var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory); + var ws = new WebSocketsTransport(options, connection.Application, connectionContext, loggerFactory); var serverSocket = await feature.AcceptAsync(); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(serverSocket); // End the app - applicationSide.Dispose(); + connection.Transport.Output.Complete(); await transport.OrTimeout(TimeSpan.FromSeconds(10)); @@ -233,11 +224,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests { using (StartLog(out var loggerFactory, LogLevel.Debug)) { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { var options = new WebSocketOptions @@ -246,7 +235,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests }; var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory); + var ws = new WebSocketsTransport(options, connection.Application, connectionContext, loggerFactory); var serverSocket = await feature.AcceptAsync(); // Give the server socket to the transport and run it @@ -256,7 +245,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var client = feature.Client.ExecuteAndCaptureFramesAsync(); // fail the client to server channel - applicationToTransport.Writer.TryComplete(new Exception()); + connection.Transport.Output.Complete(new Exception()); await Assert.ThrowsAsync(() => transport).OrTimeout(); @@ -270,11 +259,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests { using (StartLog(out var loggerFactory, LogLevel.Debug)) { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { var options = new WebSocketOptions @@ -284,7 +271,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests }; var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory); + var ws = new WebSocketsTransport(options, connection.Application, connectionContext, loggerFactory); var serverSocket = await feature.AcceptAsync(); // Give the server socket to the transport and run it @@ -294,7 +281,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var client = feature.Client.ExecuteAndCaptureFramesAsync(); // close the client to server channel - applicationToTransport.Writer.TryComplete(); + connection.Transport.Output.Complete(); _ = await client.OrTimeout(); @@ -312,11 +299,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests { using (StartLog(out var loggerFactory, LogLevel.Debug)) { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); - using (var transportSide = ChannelConnection.Create(applicationToTransport, transportToApplication)) - using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { var options = new WebSocketOptions @@ -325,7 +310,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests CloseTimeout = TimeSpan.FromSeconds(20) }; var connectionContext = new DefaultConnectionContext(string.Empty, null, null); - var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory); + var ws = new WebSocketsTransport(options, connection.Application, connectionContext, loggerFactory); var serverSocket = await feature.AcceptAsync(); // Give the server socket to the transport and run it @@ -337,7 +322,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).OrTimeout(); // close the client to server channel - applicationToTransport.Writer.TryComplete(); + connection.Transport.Output.Complete(); _ = await client.OrTimeout();